Diffstat (limited to 'tensorflow/compiler/xla/tests')
-rw-r--r--  tensorflow/compiler/xla/tests/BUILD | 1436
-rw-r--r--  tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc | 1662
-rw-r--r--  tensorflow/compiler/xla/tests/axpy_simple_test.cc | 90
-rw-r--r--  tensorflow/compiler/xla/tests/bad_rng_shape_validation_test.cc | 85
-rw-r--r--  tensorflow/compiler/xla/tests/batch_normalization_test.cc | 210
-rw-r--r--  tensorflow/compiler/xla/tests/binop_scaling_test.cc | 157
-rw-r--r--  tensorflow/compiler/xla/tests/broadcast_simple_test.cc | 179
-rw-r--r--  tensorflow/compiler/xla/tests/broadcast_test.cc | 286
-rw-r--r--  tensorflow/compiler/xla/tests/build_defs.bzl | 149
-rw-r--r--  tensorflow/compiler/xla/tests/call_test.cc | 115
-rw-r--r--  tensorflow/compiler/xla/tests/check_execution_arity_test.cc | 138
-rw-r--r--  tensorflow/compiler/xla/tests/client_library_test_base.cc | 263
-rw-r--r--  tensorflow/compiler/xla/tests/client_library_test_base.h | 409
-rw-r--r--  tensorflow/compiler/xla/tests/client_test.cc | 127
-rw-r--r--  tensorflow/compiler/xla/tests/codegen_test_base.cc | 90
-rw-r--r--  tensorflow/compiler/xla/tests/codegen_test_base.h | 56
-rw-r--r--  tensorflow/compiler/xla/tests/compilation_cache_test.cc | 218
-rw-r--r--  tensorflow/compiler/xla/tests/compute_constant_test.cc | 249
-rw-r--r--  tensorflow/compiler/xla/tests/concat_test.cc | 523
-rw-r--r--  tensorflow/compiler/xla/tests/constants_test.cc | 193
-rw-r--r--  tensorflow/compiler/xla/tests/convert_test.cc | 210
-rw-r--r--  tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc | 117
-rw-r--r--  tensorflow/compiler/xla/tests/convolution_test.cc | 361
-rw-r--r--  tensorflow/compiler/xla/tests/convolution_variants_test.cc | 1294
-rw-r--r--  tensorflow/compiler/xla/tests/copy_test.cc | 277
-rw-r--r--  tensorflow/compiler/xla/tests/custom_call_test.cc | 148
-rw-r--r--  tensorflow/compiler/xla/tests/deallocation_test.cc | 155
-rw-r--r--  tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc | 215
-rw-r--r--  tensorflow/compiler/xla/tests/dot_operation_test.cc | 387
-rw-r--r--  tensorflow/compiler/xla/tests/dynamic_ops_test.cc | 506
-rw-r--r--  tensorflow/compiler/xla/tests/floor_ceil_test.cc | 128
-rw-r--r--  tensorflow/compiler/xla/tests/fmax_test.cc | 61
-rw-r--r--  tensorflow/compiler/xla/tests/fusion_test.cc | 589
-rw-r--r--  tensorflow/compiler/xla/tests/hlo_test_base.cc | 204
-rw-r--r--  tensorflow/compiler/xla/tests/hlo_test_base.h | 107
-rw-r--r--  tensorflow/compiler/xla/tests/inprocess_service_test.cc | 204
-rw-r--r--  tensorflow/compiler/xla/tests/literal_test_util.cc | 566
-rw-r--r--  tensorflow/compiler/xla/tests/literal_test_util.h | 274
-rw-r--r--  tensorflow/compiler/xla/tests/literal_test_util_test.cc | 102
-rw-r--r--  tensorflow/compiler/xla/tests/local_client_aot_test.cc | 55
-rw-r--r--  tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc | 111
-rw-r--r--  tensorflow/compiler/xla/tests/local_client_test_base.cc | 220
-rw-r--r--  tensorflow/compiler/xla/tests/local_client_test_base.h | 146
-rw-r--r--  tensorflow/compiler/xla/tests/log_test.cc | 75
-rw-r--r--  tensorflow/compiler/xla/tests/map_test.cc | 589
-rw-r--r--  tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc | 179
-rw-r--r--  tensorflow/compiler/xla/tests/multidimensional_slice_test.cc | 74
-rw-r--r--  tensorflow/compiler/xla/tests/pad_test.cc | 420
-rw-r--r--  tensorflow/compiler/xla/tests/params_test.cc | 357
-rw-r--r--  tensorflow/compiler/xla/tests/pred_test.cc | 115
-rw-r--r--  tensorflow/compiler/xla/tests/prng_test.cc | 238
-rw-r--r--  tensorflow/compiler/xla/tests/query_inferred_shape_test.cc | 61
-rw-r--r--  tensorflow/compiler/xla/tests/reduce_test.cc | 506
-rw-r--r--  tensorflow/compiler/xla/tests/reduce_window_test.cc | 445
-rw-r--r--  tensorflow/compiler/xla/tests/replay_test.cc | 168
-rw-r--r--  tensorflow/compiler/xla/tests/reshape_motion_test.cc | 77
-rw-r--r--  tensorflow/compiler/xla/tests/reshape_test.cc | 811
-rw-r--r--  tensorflow/compiler/xla/tests/reverse_test.cc | 173
-rw-r--r--  tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc | 160
-rw-r--r--  tensorflow/compiler/xla/tests/round_trip_transfer_test.cc | 164
-rw-r--r--  tensorflow/compiler/xla/tests/scalar_computations_test.cc | 630
-rw-r--r--  tensorflow/compiler/xla/tests/select_and_scatter_test.cc | 395
-rw-r--r--  tensorflow/compiler/xla/tests/select_test.cc | 276
-rw-r--r--  tensorflow/compiler/xla/tests/set_return_value_test.cc | 116
-rw-r--r--  tensorflow/compiler/xla/tests/slice_test.cc | 277
-rw-r--r--  tensorflow/compiler/xla/tests/test_macros.h | 76
-rw-r--r--  tensorflow/compiler/xla/tests/test_utils.h | 115
-rw-r--r--  tensorflow/compiler/xla/tests/transpose_test.cc | 203
-rw-r--r--  tensorflow/compiler/xla/tests/tuple_test.cc | 415
-rw-r--r--  tensorflow/compiler/xla/tests/unary_op_test.cc | 179
-rw-r--r--  tensorflow/compiler/xla/tests/vector_ops_reduce_test.cc | 235
-rw-r--r--  tensorflow/compiler/xla/tests/vector_ops_simple_test.cc | 423
-rw-r--r--  tensorflow/compiler/xla/tests/while_test.cc | 395
73 files changed, 21419 insertions(+), 0 deletions(-)
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
new file mode 100644
index 0000000000..93fe1fee4a
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -0,0 +1,1436 @@
+# Description:
+# Base testing infrastructure for XLA.
+
+licenses(["notice"]) # Apache 2.0
+
+package(
+ default_visibility = [":friends"],
+ features = ["no_layering_check"],
+)
+
+package_group(
+ name = "friends",
+ includes = [
+ "//tensorflow/compiler/xla:friends",
+ ],
+)
+
+# Filegroup used to collect source files for dependency checking.
+filegroup(
+ name = "c_srcs",
+ data = glob([
+ "**/*.cc",
+ "**/*.h",
+ ]),
+)
+
+load("//tensorflow/compiler/xla:xla.bzl", "export_dynamic_linkopts")
+load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test")
+load("//tensorflow/compiler/xla/tests:build_defs.bzl", "generate_backend_suites")
+load("//tensorflow/compiler/xla/tests:build_defs.bzl", "generate_backend_test_macros")
+
+# Generate test_suites for all backends, named "${backend}_tests".
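+# (Assumed expansion, defined in build_defs.bzl: one native.test_suite per
+# backend, e.g. "cpu_tests", collecting the backend-specific variants of the
+# xla_test targets declared below.)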
+generate_backend_suites()
+
+cc_library(
+ name = "test_macros_header",
+ testonly = True,
+ hdrs = ["test_macros.h"],
+ deps = [
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/core:test",
+ ],
+)
+
+# Generate a test_macros_${BACKEND} library per backend with the proper copts.
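+# (Assumed expansion, defined in build_defs.bzl: per backend, a cc_library
+# mirroring :test_macros_header above but compiled with a backend-identifying
+# define, e.g.
+#
+#   cc_library(
+#       name = "test_macros_cpu",
+#       testonly = True,
+#       hdrs = ["test_macros.h"],
+#       copts = ["-DXLA_PLATFORM=CPU"],  # hypothetical flag value
+#       ...
+#   )
+#
+# so the DISABLED_ON_* macros in test_macros.h can expand per backend.)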
+generate_backend_test_macros()
+
+cc_library(
+ name = "test_utils",
+ testonly = True,
+ hdrs = ["test_utils.h"],
+ deps = [
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
+ name = "literal_test_util",
+ testonly = True,
+ srcs = ["literal_test_util.cc"],
+ hdrs = ["literal_test_util.h"],
+ deps = [
+ "//tensorflow/compiler/xla:array2d",
+ "//tensorflow/compiler/xla:array3d",
+ "//tensorflow/compiler/xla:array4d",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:test_helpers",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+cc_library(
+ name = "hlo_test_base",
+ testonly = True,
+ srcs = ["hlo_test_base.cc"],
+ hdrs = ["hlo_test_base.h"],
+ deps = [
+ ":literal_test_util",
+ "//tensorflow/compiler/xla:shape_layout",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/legacy_flags:hlo_test_base_flags",
+ "//tensorflow/compiler/xla/service",
+ "//tensorflow/compiler/xla/service:backend",
+ "//tensorflow/compiler/xla/service:compiler",
+ "//tensorflow/compiler/xla/service:computation_layout",
+ "//tensorflow/compiler/xla/service:executable",
+ "//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:hlo_execution_profile",
+ "//tensorflow/compiler/xla/service:hlo_graph_dumper",
+ "//tensorflow/compiler/xla/service:hlo_module_config",
+ "//tensorflow/compiler/xla/service:transfer_manager",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:stream_executor_no_cuda",
+ "//tensorflow/core:test",
+ "//third_party/eigen3",
+ ],
+)
+
+cc_binary(
+ name = "local_client_aot_test_helper",
+ srcs = ["local_client_aot_test_helper.cc"],
+ deps = [
+ "//tensorflow/compiler/xla/client:client_library",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/core:lib",
+ ],
+)
+
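+# The helper below AOT-compiles a trivial computation for $(TARGET_CPU) and
+# writes the resulting object code to stdout; the genrule captures that
+# stream as the .o file that local_client_aot_test links in as a source.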
+genrule(
+ name = "local_client_aot_test_computation",
+ outs = ["local_client_aot_test_computation.o"],
+ cmd = "$(location :local_client_aot_test_helper) $(TARGET_CPU) > $(OUTS)",
+ local = 1,
+ tools = [":local_client_aot_test_helper"],
+)
+
+cc_library(
+ name = "client_library_test_base",
+ testonly = True,
+ srcs = ["client_library_test_base.cc"],
+ hdrs = ["client_library_test_base.h"],
+ deps = [
+ "//tensorflow/compiler/xla:array2d",
+ "//tensorflow/compiler/xla:array3d",
+ "//tensorflow/compiler/xla:array4d",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:test_helpers",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:client_library",
+ "//tensorflow/compiler/xla/client:computation",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:global_data",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/legacy_flags:hlo_pass_pipeline_flags",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/compiler/xla/tests:test_utils",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:stream_executor_no_cuda",
+ "//tensorflow/core:test",
+ ],
+)
+
+cc_library(
+ name = "codegen_test_base",
+ testonly = True,
+ srcs = ["codegen_test_base.cc"],
+ hdrs = ["codegen_test_base.h"],
+ data = [
+ "@llvm//:FileCheck",
+ ],
+ deps = [
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla/service:backend",
+ "//tensorflow/compiler/xla/service:compiler",
+ "//tensorflow/compiler/xla/service:executable",
+ "//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:hlo_module_config",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+cc_library(
+ name = "local_client_test_base",
+ testonly = True,
+ srcs = ["local_client_test_base.cc"],
+ hdrs = ["local_client_test_base.h"],
+ deps = [
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:test_helpers",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:client_library",
+ "//tensorflow/compiler/xla/client:computation",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/service:device_memory_allocator",
+ "//tensorflow/compiler/xla/service:local_service",
+ "//tensorflow/compiler/xla/service:platform_util",
+ "//tensorflow/compiler/xla/service:shaped_buffer",
+ "//tensorflow/compiler/xla/service:transfer_manager",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:stream_executor_no_cuda",
+ ],
+)
+
+xla_test(
+ name = "bad_rng_shape_validation_test",
+ srcs = ["bad_rng_shape_validation_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:test_helpers",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:computation",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "check_execution_arity_test",
+ srcs = ["check_execution_arity_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:test_helpers",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:global_data",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "query_inferred_shape_test",
+ srcs = ["query_inferred_shape_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:test_helpers",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "while_test",
+ srcs = ["while_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:computation",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client/lib:arithmetic",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "axpy_simple_test",
+ srcs = ["axpy_simple_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "map_test",
+ srcs = ["map_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:array2d",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:test_helpers",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla:xla_proto",
+ "//tensorflow/compiler/xla/client:computation",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:global_data",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client/lib:arithmetic",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/compiler/xla/tests:test_utils",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:stream_executor_no_cuda",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "params_test",
+ srcs = ["params_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:array2d",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:computation",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:global_data",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "pred_test",
+ srcs = ["pred_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:array2d",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "select_test",
+ srcs = ["select_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:global_data",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "unary_op_test",
+ srcs = ["unary_op_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:global_data",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client/lib:arithmetic",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "scalar_computations_test",
+ srcs = ["scalar_computations_test.cc"],
+ shard_count = 16,
+ deps = [
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:test_helpers",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:global_data",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/legacy_flags:llvm_backend_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "deallocation_test",
+ srcs = ["deallocation_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:test_helpers",
+ "//tensorflow/compiler/xla/client:computation",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:global_data",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "deconstruct_tuple_test",
+ srcs = ["deconstruct_tuple_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:test_helpers",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:computation",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:global_data",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "array_elementwise_ops_test",
+ srcs = ["array_elementwise_ops_test.cc"],
+ # This test includes comparisons to NAN, so disable fast-math.
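+ # (With fast-math enabled the compiler may assume NaNs cannot occur and
+ # fold comparisons such as x != x, miscompiling the NaN cases.)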
+ backend_args = {
+ "cpu": ["--xla_fast_math=false"],
+ "cpu_parallel": ["--xla_fast_math=false"],
+ "gpu": ["--xla_fast_math=false"],
+ },
+ shard_count = 25,
+ deps = [
+ "//tensorflow/compiler/xla:array2d",
+ "//tensorflow/compiler/xla:array3d",
+ "//tensorflow/compiler/xla:array4d",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:global_data",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/legacy_flags:llvm_backend_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "dot_operation_test",
+ srcs = ["dot_operation_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:array2d",
+ "//tensorflow/compiler/xla:array3d",
+ "//tensorflow/compiler/xla:reference_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_runtime_flags",
+ "//tensorflow/compiler/xla/legacy_flags:layout_util_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/compiler/xla/tests:test_utils",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+# Tests the dot operation in some cases that can be performed via a
+# runtime call on some backends - e.g. a runtime call to Eigen.
+xla_test(
+ name = "dot_operation_runtime_test",
+ srcs = ["dot_operation_test.cc"],
+ backend_args = {
+ "cpu": ["--xla_cpu_use_eigen"],
+ "cpu_parallel": ["--xla_cpu_use_eigen"],
+ },
+ deps = [
+ "//tensorflow/compiler/xla:array2d",
+ "//tensorflow/compiler/xla:array3d",
+ "//tensorflow/compiler/xla:reference_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_runtime_flags",
+ "//tensorflow/compiler/xla/legacy_flags:layout_util_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/compiler/xla/tests:test_utils",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+# Repeat dot_operation_runtime_test with single-threaded Eigen.
+xla_test(
+ name = "dot_operation_single_threaded_runtime_test",
+ srcs = ["dot_operation_test.cc"],
+ backend_args = {
+ "cpu": [
+ "--xla_cpu_use_eigen",
+ "--xla_cpu_multi_thread_eigen=false",
+ ],
+ "cpu_parallel": [
+ "--xla_cpu_use_eigen",
+ "--xla_cpu_multi_thread_eigen=false",
+ ],
+ },
+ deps = [
+ "//tensorflow/compiler/xla:array2d",
+ "//tensorflow/compiler/xla:array3d",
+ "//tensorflow/compiler/xla:reference_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_runtime_flags",
+ "//tensorflow/compiler/xla/legacy_flags:layout_util_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/compiler/xla/tests:test_utils",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "dot_operation_rowmajor_runtime_test",
+ srcs = ["dot_operation_test.cc"],
+ backend_args = {
+ "cpu": [
+ "--xla_cpu_use_eigen",
+ "--xla_default_layout=major2minor",
+ ],
+ "cpu_parallel": [
+ "--xla_cpu_use_eigen",
+ "--xla_default_layout=major2minor",
+ ],
+ },
+ deps = [
+ "//tensorflow/compiler/xla:array2d",
+ "//tensorflow/compiler/xla:array3d",
+ "//tensorflow/compiler/xla:reference_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_runtime_flags",
+ "//tensorflow/compiler/xla/legacy_flags:layout_util_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/compiler/xla/tests:test_utils",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "transpose_test",
+ srcs = ["transpose_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:array2d",
+ "//tensorflow/compiler/xla:reference_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "constants_test",
+ srcs = ["constants_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:array2d",
+ "//tensorflow/compiler/xla:array3d",
+ "//tensorflow/compiler/xla:array4d",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "convolution_test",
+ timeout = "long",
+ srcs = ["convolution_test.cc"],
+ shard_count = 25,
+ deps = [
+ "//tensorflow/compiler/xla:array2d",
+ "//tensorflow/compiler/xla:array4d",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:reference_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:global_data",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client:padding",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "convolution_variants_test",
+ timeout = "long",
+ srcs = ["convolution_variants_test.cc"],
+ backend_tags = {
+ # TODO(b/31436974): Fix msan failure. Failed on 2016-09-12.
+ "cpu": ["nomsan"],
+ "cpu_parallel": ["nomsan"],
+ },
+ shard_count = 30,
+ deps = [
+ "//tensorflow/compiler/xla:array4d",
+ "//tensorflow/compiler/xla:reference_util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client:padding",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "convolution_dimension_numbers_test",
+ timeout = "long",
+ srcs = ["convolution_dimension_numbers_test.cc"],
+ shard_count = 20,
+ deps = [
+ "//tensorflow/compiler/xla:array4d",
+ "//tensorflow/compiler/xla:reference_util",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:test_helpers",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client:padding",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "batch_normalization_test",
+ srcs = ["batch_normalization_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:array2d",
+ "//tensorflow/compiler/xla:array4d",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:computation",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client/lib:arithmetic",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "slice_test",
+ srcs = ["slice_test.cc"],
+ shard_count = 40,
+ deps = [
+ "//tensorflow/compiler/xla:array2d",
+ "//tensorflow/compiler/xla:reference_util",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "multidimensional_slice_test",
+ srcs = ["multidimensional_slice_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:array2d",
+ "//tensorflow/compiler/xla:array3d",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "dynamic_ops_test",
+ srcs = ["dynamic_ops_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:array2d",
+ "//tensorflow/compiler/xla:reference_util",
+ "//tensorflow/compiler/xla/client:client_library",
+ "//tensorflow/compiler/xla/client:computation",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/service:device_memory_allocator",
+ "//tensorflow/compiler/xla/service:local_service",
+ "//tensorflow/compiler/xla/service:platform_util",
+ "//tensorflow/compiler/xla/service:shaped_buffer",
+ "//tensorflow/compiler/xla/service:transfer_manager",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:stream_executor_no_cuda",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "tuple_test",
+ srcs = ["tuple_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:array2d",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:test_helpers",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:computation",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "vector_ops_reduce_test",
+ srcs = ["vector_ops_reduce_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:array2d",
+ "//tensorflow/compiler/xla:array3d",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client/lib:arithmetic",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "reduce_test",
+ srcs = ["reduce_test.cc"],
+ shard_count = 40,
+ deps = [
+ "//tensorflow/compiler/xla:array2d",
+ "//tensorflow/compiler/xla:array4d",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:reference_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:computation",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:global_data",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client/lib:arithmetic",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "reduce_window_test",
+ timeout = "long",
+ srcs = ["reduce_window_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:array2d",
+ "//tensorflow/compiler/xla:array3d",
+ "//tensorflow/compiler/xla:array4d",
+ "//tensorflow/compiler/xla:reference_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client:padding",
+ "//tensorflow/compiler/xla/client/lib:arithmetic",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "select_and_scatter_test",
+ timeout = "long",
+ srcs = ["select_and_scatter_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:array2d",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:reference_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:computation",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client:padding",
+ "//tensorflow/compiler/xla/client/lib:arithmetic",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "copy_test",
+ srcs = ["copy_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:array2d",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "call_test",
+ srcs = ["call_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:computation",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "custom_call_test",
+ srcs = ["custom_call_test.cc"],
+ linkopts = export_dynamic_linkopts,
+ deps = [
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "binop_scaling_test",
+ srcs = ["binop_scaling_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:array2d",
+ "//tensorflow/compiler/xla:array4d",
+ "//tensorflow/compiler/xla:reference_util",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "broadcast_simple_test",
+ srcs = ["broadcast_simple_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:array2d",
+ "//tensorflow/compiler/xla:array4d",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:test_helpers",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "pad_test",
+ srcs = ["pad_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:array2d",
+ "//tensorflow/compiler/xla:array4d",
+ "//tensorflow/compiler/xla:reference_util",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client/lib:arithmetic",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "fmax_test",
+ srcs = ["fmax_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "log_test",
+ srcs = ["log_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "matrix_ops_simple_test",
+ srcs = ["matrix_ops_simple_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:array2d",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:reference_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:test_helpers",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:computation",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "prng_test",
+ srcs = ["prng_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "reshape_test",
+ srcs = ["reshape_test.cc"],
+ shard_count = 30,
+ deps = [
+ "//tensorflow/compiler/xla:array2d",
+ "//tensorflow/compiler/xla:array4d",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:reference_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:test_helpers",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:computation",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:global_data",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "reverse_test",
+ srcs = ["reverse_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:array2d",
+ "//tensorflow/compiler/xla:array4d",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "vector_ops_simple_test",
+ srcs = ["vector_ops_simple_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:array4d",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:test_helpers",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:computation",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:global_data",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client/lib:arithmetic",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:stream_executor_no_cuda",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "concat_test",
+ srcs = ["concat_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:array2d",
+ "//tensorflow/compiler/xla:array3d",
+ "//tensorflow/compiler/xla:reference_util",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:test_helpers",
+ "//tensorflow/compiler/xla/client:computation",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "convert_test",
+ srcs = ["convert_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:stream_executor_no_cuda",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "compilation_cache_test",
+ srcs = ["compilation_cache_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla:xla_proto",
+ "//tensorflow/compiler/xla/client:computation",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:global_data",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/compiler/xla/tests:test_utils",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "floor_ceil_test",
+ srcs = ["floor_ceil_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "compute_constant_test",
+ srcs = ["compute_constant_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:test_helpers",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:computation",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:global_data",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/compiler/xla/tests:test_utils",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "client_test",
+ srcs = ["client_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:test_helpers",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:global_data",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/compiler/xla/tests:test_utils",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "inprocess_service_test",
+ srcs = ["inprocess_service_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:array2d",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:test_helpers",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:computation",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:global_data",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "replay_test",
+ srcs = ["replay_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:protobuf_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:computation",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:global_data",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/service:session_proto",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "broadcast_test",
+ srcs = ["broadcast_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "round_trip_packed_literal_test",
+ srcs = ["round_trip_packed_literal_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:packed_literal_reader",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:global_data",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "fusion_test",
+ srcs = ["fusion_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:array2d",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
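+# Links the object file produced by :local_client_aot_test_computation above;
+# a plain cc_test rather than an xla_test, since the AOT-compiled computation
+# is meant to run without depending on the XLA runtime.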
+cc_test(
+ name = "local_client_aot_test",
+ srcs = [
+ "local_client_aot_test.cc",
+ ":local_client_aot_test_computation.o",
+ ],
+ deps = [
+ "//tensorflow/compiler/xla:executable_run_options",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+xla_test(
+ name = "round_trip_transfer_test",
+ srcs = ["round_trip_transfer_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:array4d",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:global_data",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "set_return_value_test",
+ srcs = ["set_return_value_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "reshape_motion_test",
+ srcs = ["reshape_motion_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:array2d",
+ "//tensorflow/compiler/xla:array4d",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:reference_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:test_helpers",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:computation",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:global_data",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+cc_test(
+ name = "literal_test_util_test",
+ srcs = ["literal_test_util_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:test_helpers",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+# -----------------------------------------------------------------------------
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc
new file mode 100644
index 0000000000..cf6f9a825c
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc
@@ -0,0 +1,1662 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <cmath>
+#include <limits>
+#include <memory>
+#include <numeric>
+#include <vector>
+
+#include "tensorflow/compiler/xla/array2d.h"
+#include "tensorflow/compiler/xla/array3d.h"
+#include "tensorflow/compiler/xla/array4d.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/global_data.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/legacy_flags/llvm_backend_flags.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+class ArrayElementwiseOpTest : public ClientLibraryTestBase {
+ public:
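+  // Absolute error tolerance for the floating-point ComputeAndCompare*
+  // checks below.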
+ ErrorSpec error_spec_{0.0001};
+};
+
+class ArrayElementwiseOpTestParamCount
+ : public ArrayElementwiseOpTest,
+ public ::testing::WithParamInterface<int> {};
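+// (The element count for this fixture is expected to be supplied by an
+// INSTANTIATE_TEST_CASE_P elsewhere in the file.)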
+
+XLA_TEST_F(ArrayElementwiseOpTest, NegConstantZeroElementF32) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<float>({});
+ auto result = builder.Neg(a);
+
+ ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
+}
+
+TEST_F(ArrayElementwiseOpTest, NegConstantF32) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<float>({-2.5f, 3.14f, 2.25f, -10.0f, 6.0f});
+ auto result = builder.Neg(a);
+
+ ComputeAndCompareR1<float>(&builder, {2.5f, -3.14f, -2.25f, 10.0f, -6.0f}, {},
+ error_spec_);
+}
+
+TEST_F(ArrayElementwiseOpTest, NegConstantS32) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<int32>({-1, 0, 1, 324,
+ std::numeric_limits<int32>::min(),
+ std::numeric_limits<int32>::max()});
+ auto result = builder.Neg(a);
+
+  // -min == min for int32: negating the minimum value overflows (e.g.
+  // -(-2147483648) wraps back to -2147483648). In C++ this overflow is
+  // undefined behavior; XLA has not specified it as undefined, so the
+  // wrapped value is expected here.
+ ComputeAndCompareR1<int32>(&builder,
+ {1, 0, -1, -324, std::numeric_limits<int32>::min(),
+ -std::numeric_limits<int32>::max()},
+ {});
+}
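+
+// A minimal sketch of the wraparound described above (an editor's
+// illustration; this helper is hypothetical and used only by the
+// static_assert below). Two's-complement negation is ~x + 1 computed with
+// wrapping arithmetic, which maps the minimum int32 back to itself.
+constexpr int32 TwosComplementNegate(int32 x) {
+  // Cast to uint32 so that the + 1 wraps instead of overflowing.
+  return static_cast<int32>(~static_cast<uint32>(x) + 1u);
+}
+static_assert(TwosComplementNegate(std::numeric_limits<int32>::min()) ==
+                  std::numeric_limits<int32>::min(),
+              "negating the minimum int32 wraps back to itself");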
+
+TEST_F(ArrayElementwiseOpTest, AddTwoConstantF32s) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<float>({-2.5f, 3.14f, 2.25f, -10.0f, 6.0f});
+ auto b = builder.ConstantR1<float>({100.0f, 3.13f, 2.75f, 10.5f, -999.0f});
+ auto add = builder.Add(a, b);
+
+ ComputeAndCompareR1<float>(&builder, {97.5f, 6.27f, 5.0f, 0.5f, -993.0f}, {},
+ error_spec_);
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantZeroElementF32s) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<float>({});
+ auto b = builder.ConstantR1<float>({});
+ auto add = builder.Add(a, b);
+
+ ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
+}
+
+TEST_P(ArrayElementwiseOpTestParamCount, AddManyValues) {
+ const int count = GetParam();
+ ComputationBuilder builder(client_, TestName());
+ std::vector<float> a_values;
+ std::vector<float> b_values;
+ for (int i = 0; i < count; ++i) {
+ a_values.push_back(i / static_cast<float>(count));
+ b_values.push_back(2 * i / static_cast<float>(count + 2));
+ }
+
+  std::unique_ptr<Literal> a_literal = LiteralUtil::CreateR1<float>(a_values);
+ std::unique_ptr<GlobalData> a_data =
+ client_->TransferToServer(*a_literal).ConsumeValueOrDie();
+ auto a_constant = builder.ConstantR1<float>(a_values);
+ auto a_param = builder.Parameter(0, a_literal->shape(), "a_param");
+
+  std::unique_ptr<Literal> b_literal = LiteralUtil::CreateR1<float>(b_values);
+ std::unique_ptr<GlobalData> b_data =
+ client_->TransferToServer(*b_literal).ConsumeValueOrDie();
+  auto b_constant = builder.ConstantR1<float>(b_values);
+  auto b_param = builder.Parameter(1, b_literal->shape(), "b_param");
+
+ auto sum1 = builder.Add(a_constant, b_constant);
+ auto sum2 = builder.Add(a_constant, b_param);
+ auto sum3 = builder.Add(a_param, b_constant);
+ auto sum4 = builder.Add(a_param, b_param);
+
+ auto sum = builder.Add(sum1, sum2);
+ sum = builder.Add(sum, sum3);
+ sum = builder.Add(sum, sum4);
+
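+  // Each of the four partial sums equals a + b elementwise, so the result is
+  // 4 * (a + b).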
+ std::vector<float> expected;
+ for (int64 i = 0; i < count; ++i) {
+ expected.push_back(4 * (a_values[i] + b_values[i]));
+ }
+
+ ComputeAndCompareR1<float>(&builder, expected, {a_data.get(), b_data.get()},
+ error_spec_);
+}
+
+TEST_F(ArrayElementwiseOpTest, SubTwoConstantF32s) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<float>({-2.5f, 3.14f, 2.25f, -10.0f, 6.0f});
+ auto b = builder.ConstantR1<float>({100.0f, 3.13f, 2.75f, 10.5f, -999.0f});
+  auto sub = builder.Sub(a, b);
+
+ ComputeAndCompareR1<float>(&builder, {-102.5f, 0.01f, -0.5f, -20.5f, 1005.0f},
+ {}, error_spec_);
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantZeroElementF32s) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<float>({});
+ auto b = builder.ConstantR1<float>({});
+  auto sub = builder.Sub(a, b);
+
+ ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
+}
+
+TEST_F(ArrayElementwiseOpTest, SubTwoConstantS32s) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<int32>({-1, 0, 2, 1000000000});
+ auto b = builder.ConstantR1<int32>({-1, 2, 1, -1});
+  auto sub = builder.Sub(a, b);
+
+ ComputeAndCompareR1<int32>(&builder, {0, -2, 1, 1000000001}, {});
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantZeroElementS32s) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<int32>({});
+ auto b = builder.ConstantR1<int32>({});
+  auto sub = builder.Sub(a, b);
+
+ ComputeAndCompareR1<int32>(&builder, {}, {});
+}
+
+TEST_F(ArrayElementwiseOpTest, DivTwoConstantF32s) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, -10.0f, 6.0f});
+ auto b = builder.ConstantR1<float>({10.0f, 5.1f, 1.0f, 10.0f, -6.0f});
+  auto div = builder.Div(a, b);
+
+ ComputeAndCompareR1<float>(&builder, {-0.25f, 5.0f, 2.25f, -1.0f, -1.0f}, {},
+ error_spec_);
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantZeroElementF32s) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<float>({});
+ auto b = builder.ConstantR1<float>({});
+  auto div = builder.Div(a, b);
+
+ ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, RemF32s) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<float>(
+ {-2.5f, 25.5f, 2.25f, -10.0f, 6.0f, 3.0f, 3.0f, -1.0f, -8.0f});
+ auto b = builder.ConstantR1<float>(
+ {10.0f, 5.1f, 1.0f, 10.0f, -6.0f, 2.0f, -2.0f, 7.0f, -4.0f});
+  auto rem = builder.Rem(a, b);
+
+ ComputeAndCompareR1<float>(
+ &builder, {-2.5f, 0.0f, 0.25f, 0.0f, -0.0f, 1.0f, 1.0f, -1.0f, -0.0f}, {},
+ error_spec_);
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, RemZeroElementF32s) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<float>({});
+ auto b = builder.ConstantR1<float>({});
+  auto rem = builder.Rem(a, b);
+
+ ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, RemF64s) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<double>(
+ {-2.5, 25.5, 2.25, -10.0, 6.0, 3.0, 3.0, -1.0, -8.0});
+ auto b = builder.ConstantR1<double>(
+ {10.0, 5.1, 1.0, 10.0, -6.0, 2.0, -2.0, 7.0, -4.0});
+  auto rem = builder.Rem(a, b);
+
+ ComputeAndCompareR1<double>(
+ &builder, {-2.5, 0.0, 0.25, 0.0, -0.0, 1.0, 1.0, -1.0, -0.0}, {},
+ error_spec_);
+}
+
+TEST_F(ArrayElementwiseOpTest, MulTwoConstantF32s) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, -10.0f, 6.0f});
+ auto b = builder.ConstantR1<float>({10.0f, 5.0f, 1.0f, 10.0f, -6.0f});
+  auto mul = builder.Mul(a, b);
+
+ ComputeAndCompareR1<float>(&builder, {-25.0f, 127.5f, 2.25f, -100.0f, -36.0f},
+ {}, error_spec_);
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantZeroElementF32s) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<float>({});
+ auto b = builder.ConstantR1<float>({});
+  auto mul = builder.Mul(a, b);
+
+ ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
+}
+
+TEST_F(ArrayElementwiseOpTest, MulTwoConstantS32s) {
+ std::vector<int32> data = {0,
+ 1,
+ -1,
+ 1234,
+ 0x1a243514,
+ std::numeric_limits<int32>::max(),
+ std::numeric_limits<int32>::min()};
+ // Form the test data set using all products of 'data' with itself.
+ std::vector<int32> a_data, b_data, expected;
+ for (int32 a : data) {
+ for (int32 b : data) {
+ a_data.push_back(a);
+ b_data.push_back(b);
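+      // Compute the expected product with uint32 arithmetic: signed-integer
+      // overflow is undefined in C++, and the wrapped two's-complement
+      // product is what the test expects from XLA's s32 multiply.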
+ expected.push_back(static_cast<uint32>(a) * static_cast<uint32>(b));
+ }
+ }
+
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<int32>(a_data);
+ auto b = builder.ConstantR1<int32>(b_data);
+  auto mul = builder.Mul(a, b);
+
+ ComputeAndCompareR1<int32>(&builder, expected, {});
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantZeroElementS32s) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<int32>({});
+ auto b = builder.ConstantR1<int32>({});
+  auto mul = builder.Mul(a, b);
+
+ ComputeAndCompareR1<int32>(&builder, {}, {});
+}
+
+TEST_F(ArrayElementwiseOpTest, MulTwoConstantU32s) {
+ std::vector<uint32> data = {0, 1, 0xDEADBEEF, 1234,
+ 0x1a243514, 0xFFFFFFFF, 0x80808080};
+
+ // Form the test data set using all products of 'data' with itself.
+ std::vector<uint32> a_data, b_data, expected;
+ for (uint32 a : data) {
+ for (uint32 b : data) {
+ a_data.push_back(a);
+ b_data.push_back(b);
+ expected.push_back(a * b);
+ }
+ }
+
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<uint32>(a_data);
+ auto b = builder.ConstantR1<uint32>(b_data);
+  auto mul = builder.Mul(a, b);
+
+ ComputeAndCompareR1<uint32>(&builder, expected, {});
+}
+
+TEST_F(ArrayElementwiseOpTest, LogicalAnd) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<bool>({false, false, true, true});
+ auto b = builder.ConstantR1<bool>({false, true, false, true});
+ auto out = builder.LogicalAnd(a, b);
+
+ ComputeAndCompareR1<bool>(&builder, {false, false, false, true}, {});
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, LogicalAndZeroElement) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<bool>({});
+ auto b = builder.ConstantR1<bool>({});
+ auto out = builder.LogicalAnd(a, b);
+
+ ComputeAndCompareR1<bool>(&builder, {}, {});
+}
+
+TEST_F(ArrayElementwiseOpTest, LogicalOr) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<bool>({false, false, true, true});
+ auto b = builder.ConstantR1<bool>({false, true, false, true});
+ auto out = builder.LogicalOr(a, b);
+
+ ComputeAndCompareR1<bool>(&builder, {false, true, true, true}, {});
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, LogicalOrZeroElement) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<bool>({});
+ auto b = builder.ConstantR1<bool>({});
+ auto out = builder.LogicalOr(a, b);
+
+ ComputeAndCompareR1<bool>(&builder, {}, {});
+}
+
+TEST_F(ArrayElementwiseOpTest, LogicalNot) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<bool>({false, true, true, false});
+ auto out = builder.LogicalNot(a);
+
+ ComputeAndCompareR1<bool>(&builder, {true, false, false, true}, {});
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, LogicalNotZeroElement) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<bool>({});
+ auto out = builder.LogicalNot(a);
+
+ ComputeAndCompareR1<bool>(&builder, {}, {});
+}
+
+TEST_F(ArrayElementwiseOpTest, CompareEqF32s) {
+ ComputationBuilder builder(client_, TestName());
+ auto lhs = builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, NAN, 6.0f});
+ auto rhs = builder.ConstantR1<float>({10.0f, 5.0f, 2.25f, 10.0f, NAN});
+ auto compare = builder.Eq(lhs, rhs);
+
+ ComputeAndCompareR1<bool>(&builder, {false, false, true, false, false}, {});
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementF32s) {
+ ComputationBuilder builder(client_, TestName());
+ auto lhs = builder.ConstantR1<float>({});
+ auto rhs = builder.ConstantR1<float>({});
+ auto compare = builder.Eq(lhs, rhs);
+
+ ComputeAndCompareR1<bool>(&builder, {}, {});
+}
+
+TEST_F(ArrayElementwiseOpTest, CompareGeF32s) {
+ ComputationBuilder builder(client_, TestName());
+ auto lhs = builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, NAN, 6.0f});
+ auto rhs = builder.ConstantR1<float>({10.0f, 5.0f, 1.0f, 10.0f, NAN});
+ auto compare = builder.Ge(lhs, rhs);
+
+ ComputeAndCompareR1<bool>(&builder, {false, true, true, false, false}, {});
+}
+
+TEST_F(ArrayElementwiseOpTest, CompareGtF32s) {
+ ComputationBuilder builder(client_, TestName());
+ auto lhs = builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, NAN, 6.0f});
+ auto rhs = builder.ConstantR1<float>({10.0f, 5.0f, 1.0f, 10.0f, NAN});
+ auto compare = builder.Gt(lhs, rhs);
+
+ ComputeAndCompareR1<bool>(&builder, {false, true, true, false, false}, {});
+}
+
+TEST_F(ArrayElementwiseOpTest, CompareLeF32s) {
+ ComputationBuilder builder(client_, TestName());
+ auto lhs = builder.ConstantR1<float>({-2.5f, 5.0f, 2.25f, NAN, 6.0f});
+ auto rhs = builder.ConstantR1<float>({10.0f, 5.0f, 1.0f, 10.0f, NAN});
+ auto compare = builder.Le(lhs, rhs);
+
+ ComputeAndCompareR1<bool>(&builder, {true, true, false, false, false}, {});
+}
+
+TEST_F(ArrayElementwiseOpTest, CompareLtF32s) {
+ ComputationBuilder builder(client_, TestName());
+ auto lhs = builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, NAN, 6.0f});
+ auto rhs = builder.ConstantR1<float>({10.0f, 5.0f, 1.0f, 10.0f, NAN});
+ auto compare = builder.Lt(lhs, rhs);
+
+ ComputeAndCompareR1<bool>(&builder, {true, false, false, false, false}, {});
+}
+
+TEST_F(ArrayElementwiseOpTest, CompareEqS32s) {
+ const int32 min = std::numeric_limits<int32>::min();
+ const int32 max = std::numeric_limits<int32>::max();
+ ComputationBuilder builder(client_, TestName());
+ auto lhs = builder.ConstantR1<int32>({min, min, min, 0, 0, 0, max, max, max});
+ auto rhs = builder.ConstantR1<int32>({min, 0, max, -1, 0, 1, min, 0, max});
+ auto compare = builder.Eq(lhs, rhs);
+
+ ComputeAndCompareR1<bool>(
+ &builder, {true, false, false, false, true, false, false, false, true},
+ {});
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementS32s) {
+ ComputationBuilder builder(client_, TestName());
+ auto lhs = builder.ConstantR1<int32>({});
+ auto rhs = builder.ConstantR1<int32>({});
+ auto compare = builder.Eq(lhs, rhs);
+
+ ComputeAndCompareR1<bool>(&builder, {}, {});
+}
+
+TEST_F(ArrayElementwiseOpTest, CompareNeS32s) {
+ const int32 min = std::numeric_limits<int32>::min();
+ const int32 max = std::numeric_limits<int32>::max();
+ ComputationBuilder builder(client_, TestName());
+ auto lhs = builder.ConstantR1<int32>({min, min, min, 0, 0, 0, max, max, max});
+ auto rhs = builder.ConstantR1<int32>({min, 0, max, -1, 0, 1, min, 0, max});
+ auto compare = builder.Ne(lhs, rhs);
+
+ ComputeAndCompareR1<bool>(
+ &builder, {false, true, true, true, false, true, true, true, false}, {});
+}
+
+TEST_F(ArrayElementwiseOpTest, CompareGeS32s) {
+ const int32 min = std::numeric_limits<int32>::min();
+ const int32 max = std::numeric_limits<int32>::max();
+ ComputationBuilder builder(client_, TestName());
+ auto lhs = builder.ConstantR1<int32>({min, min, min, 0, 0, 0, max, max, max});
+ auto rhs = builder.ConstantR1<int32>({min, 0, max, -1, 0, 1, min, 0, max});
+ auto compare = builder.Ge(lhs, rhs);
+
+ ComputeAndCompareR1<bool>(
+ &builder, {true, false, false, true, true, false, true, true, true}, {});
+}
+
+TEST_F(ArrayElementwiseOpTest, CompareGtS32s) {
+ const int32 min = std::numeric_limits<int32>::min();
+ const int32 max = std::numeric_limits<int32>::max();
+ ComputationBuilder builder(client_, TestName());
+ auto lhs = builder.ConstantR1<int32>({min, min, min, 0, 0, 0, max, max, max});
+ auto rhs = builder.ConstantR1<int32>({min, 0, max, -1, 0, 1, min, 0, max});
+ auto compare = builder.Gt(lhs, rhs);
+
+ ComputeAndCompareR1<bool>(
+ &builder, {false, false, false, true, false, false, true, true, false},
+ {});
+}
+
+TEST_F(ArrayElementwiseOpTest, CompareLeS32s) {
+ const int32 min = std::numeric_limits<int32>::min();
+ const int32 max = std::numeric_limits<int32>::max();
+ ComputationBuilder builder(client_, TestName());
+ auto lhs = builder.ConstantR1<int32>({min, min, min, 0, 0, 0, max, max, max});
+ auto rhs = builder.ConstantR1<int32>({min, 0, max, -1, 0, 1, min, 0, max});
+ auto compare = builder.Le(lhs, rhs);
+
+ ComputeAndCompareR1<bool>(
+ &builder, {true, true, true, false, true, true, false, false, true}, {});
+}
+
+TEST_F(ArrayElementwiseOpTest, CompareLtS32s) {
+ const int32 min = std::numeric_limits<int32>::min();
+ const int32 max = std::numeric_limits<int32>::max();
+ ComputationBuilder builder(client_, TestName());
+ auto lhs = builder.ConstantR1<int32>({min, min, min, 0, 0, 0, max, max, max});
+ auto rhs = builder.ConstantR1<int32>({min, 0, max, -1, 0, 1, min, 0, max});
+ auto compare = builder.Lt(lhs, rhs);
+
+ ComputeAndCompareR1<bool>(
+ &builder, {false, true, true, false, false, true, false, false, false},
+ {});
+}
+
+TEST_F(ArrayElementwiseOpTest, CompareEqU32s) {
+ const uint32 max = std::numeric_limits<uint32>::max();
+ ComputationBuilder builder(client_, TestName());
+ auto lhs = builder.ConstantR1<uint32>({0, 0, 0, 5, 5, 5, max, max, max});
+ auto rhs = builder.ConstantR1<uint32>({0, 1, max, 4, 5, 6, 0, 1, max});
+ auto compare = builder.Eq(lhs, rhs);
+
+ ComputeAndCompareR1<bool>(
+ &builder, {true, false, false, false, true, false, false, false, true},
+ {});
+}
+
+TEST_F(ArrayElementwiseOpTest, CompareNeU32s) {
+ const uint32 max = std::numeric_limits<uint32>::max();
+ ComputationBuilder builder(client_, TestName());
+ auto lhs = builder.ConstantR1<uint32>({0, 0, 0, 5, 5, 5, max, max, max});
+ auto rhs = builder.ConstantR1<uint32>({0, 1, max, 4, 5, 6, 0, 1, max});
+ auto compare = builder.Ne(lhs, rhs);
+
+ ComputeAndCompareR1<bool>(
+ &builder, {false, true, true, true, false, true, true, true, false}, {});
+}
+
+TEST_F(ArrayElementwiseOpTest, CompareGeU32s) {
+ const uint32 max = std::numeric_limits<uint32>::max();
+ ComputationBuilder builder(client_, TestName());
+ auto lhs = builder.ConstantR1<uint32>({0, 0, 0, 5, 5, 5, max, max, max});
+ auto rhs = builder.ConstantR1<uint32>({0, 1, max, 4, 5, 6, 0, 1, max});
+ auto compare = builder.Ge(lhs, rhs);
+
+ ComputeAndCompareR1<bool>(
+ &builder, {true, false, false, true, true, false, true, true, true}, {});
+}
+
+TEST_F(ArrayElementwiseOpTest, CompareGtU32s) {
+ const uint32 max = std::numeric_limits<uint32>::max();
+ ComputationBuilder builder(client_, TestName());
+ auto lhs = builder.ConstantR1<uint32>({0, 0, 0, 5, 5, 5, max, max, max});
+ auto rhs = builder.ConstantR1<uint32>({0, 1, max, 4, 5, 6, 0, 1, max});
+ auto compare = builder.Gt(lhs, rhs);
+
+ ComputeAndCompareR1<bool>(
+ &builder, {false, false, false, true, false, false, true, true, false},
+ {});
+}
+
+TEST_F(ArrayElementwiseOpTest, CompareLeU32s) {
+ const uint32 max = std::numeric_limits<uint32>::max();
+ ComputationBuilder builder(client_, TestName());
+ auto lhs = builder.ConstantR1<uint32>({0, 0, 0, 5, 5, 5, max, max, max});
+ auto rhs = builder.ConstantR1<uint32>({0, 1, max, 4, 5, 6, 0, 1, max});
+ auto compare = builder.Le(lhs, rhs);
+
+ ComputeAndCompareR1<bool>(
+ &builder, {true, true, true, false, true, true, false, false, true}, {});
+}
+
+TEST_F(ArrayElementwiseOpTest, CompareLtU32s) {
+ const uint32 max = std::numeric_limits<uint32>::max();
+ ComputationBuilder builder(client_, TestName());
+ auto lhs = builder.ConstantR1<uint32>({0, 0, 0, 5, 5, 5, max, max, max});
+ auto rhs = builder.ConstantR1<uint32>({0, 1, max, 4, 5, 6, 0, 1, max});
+ auto compare = builder.Lt(lhs, rhs);
+
+ ComputeAndCompareR1<bool>(
+ &builder, {false, true, true, false, false, true, false, false, false},
+ {});
+}
+
+TEST_F(ArrayElementwiseOpTest, PowF32s) {
+ ComputationBuilder builder(client_, TestName());
+ auto lhs = builder.ConstantR1<float>({4.0f, 2.0f, 2.0f, NAN, 6.0f});
+ auto rhs = builder.ConstantR1<float>({2.0f, -2.0f, 3.0f, 10.0f, NAN});
+  auto power = builder.Pow(lhs, rhs);
+
+ ComputeAndCompareR1<float>(&builder, {16.0f, 0.25f, 8.0f, NAN, NAN}, {},
+ error_spec_);
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, PowZeroElementF32s) {
+ ComputationBuilder builder(client_, TestName());
+ auto lhs = builder.ConstantR1<float>({});
+ auto rhs = builder.ConstantR1<float>({});
+  auto power = builder.Pow(lhs, rhs);
+
+ ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
+}
+
+// Some Pow cases that can be implemented more efficiently.
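+// (For example: pow(x, 0) = 1, pow(x, 1) = x, pow(x, 2) = x * x,
+// pow(x, 0.5) = sqrt(x), pow(x, -1) = 1 / x, pow(x, -0.5) = 1 / sqrt(x).)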
+TEST_F(ArrayElementwiseOpTest, PowSpecialF32) {
+ ComputationBuilder b(client_, TestName());
+
+ std::vector<float> values = {1.0f, 2.0f, 3.2f, -4.0f};
+ std::vector<float> exponents = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
+
+ std::unique_ptr<Literal> param_literal = LiteralUtil::CreateR1<float>(values);
+ std::unique_ptr<GlobalData> param_data =
+ client_->TransferToServer(*param_literal).ConsumeValueOrDie();
+
+ auto sum = b.ConstantR0<float>(0.0f);
+ auto param = b.Parameter(0, param_literal->shape(), "param");
+ for (float exponent : exponents) {
+ sum = b.Add(sum, b.Pow(param, b.ConstantR0<float>(exponent)));
+ }
+
+ std::vector<float> expected;
+ for (auto value : values) {
+ float sum = 0.0f;
+ for (float exponent : exponents) {
+ sum += std::pow(value, exponent);
+ }
+ expected.push_back(sum);
+ }
+
+ ComputeAndCompareR1<float>(&b, expected, {param_data.get()}, error_spec_);
+}
+
+TEST_P(ArrayElementwiseOpTestParamCount, SquareManyValues) {
+ const int count = GetParam();
+ ComputationBuilder builder(client_, TestName());
+ std::vector<float> values;
+ for (int i = 0; i < count; ++i) {
+ values.push_back(i / static_cast<float>(count));
+ }
+ auto x = builder.ConstantR1<float>(values);
+ auto exp = builder.Pow(x, builder.ConstantR0<float>(2.0f));
+
+ std::vector<float> expected;
+ for (float value : values) {
+ expected.push_back(value * value);
+ }
+
+ ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
+}
+
+TEST_F(ArrayElementwiseOpTest, SquareIn4D) {
+ ComputationBuilder builder(client_, TestName());
+ Array4D<float> values(2, 2, 2, 2);
+
+ std::vector<float> values_vector;
+ std::vector<float> expected_vector;
+ for (int i = 0; i < values.num_elements(); ++i) {
+ values_vector.push_back(static_cast<float>(i) / values.num_elements());
+ expected_vector.push_back(values_vector.back() * values_vector.back());
+ }
+ values.SetValues(values_vector);
+
+ Array4D<float> expected(2, 2, 2, 2, expected_vector);
+
+ auto x = builder.ConstantR4FromArray4D<float>(values);
+ auto exp = builder.Pow(x, builder.ConstantR0<float>(2.0f));
+
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, SquareIn4DZeroElements) {
+ ComputationBuilder builder(client_, TestName());
+ Array4D<float> values(2, 2, 0, 2);
+ Array4D<float> expected(2, 2, 0, 2);
+
+ auto x = builder.ConstantR4FromArray4D<float>(values);
+ auto exp = builder.Pow(x, builder.ConstantR0<float>(2.0f));
+
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+// The GPU backend emits nvvm intrinsics for fmin and fmax whose semantics do
+// NOT satisfy
+// * fmin(NaN, x) = x
+// * fmax(NaN, x) = x
+// so we only test NaN on the CPU.
+//
+// TODO(b/28180546): Make this compile in a way that is consistent
+// among backends.
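+//
+// For reference, the CPU expectations below match the semantics of
+// std::fmin/std::fmax from <cmath>: when exactly one operand is NaN, the
+// other operand is returned, e.g. fmin(NaN, 10.0f) == 10.0f.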
+TEST_F(ArrayElementwiseOpTest, MinF32s) {
+ ComputationBuilder builder(client_, TestName());
+#if !defined(XLA_TEST_BACKEND_CPU)
+ auto lhs = builder.ConstantR1<float>({1.0f, 1.0f, 2.25f});
+ auto rhs = builder.ConstantR1<float>({2.0f, -5.0f, 1.0f});
+#else
+ auto lhs = builder.ConstantR1<float>({1.0f, 1.0f, 2.25f, NAN, 6.0f});
+ auto rhs = builder.ConstantR1<float>({2.0f, -5.0f, 1.0f, 10.0f, NAN});
+#endif
+ auto minimum = builder.Min(lhs, rhs);
+
+ ComputeAndCompareR1<float>(&builder,
+#if !defined(XLA_TEST_BACKEND_CPU)
+ {1.0f, -5.0f, 1.0f},
+#else
+ {1.0f, -5.0f, 1.0f, 10.0f, 6.0f},
+#endif
+ {}, error_spec_);
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, MinZeroElementF32s) {
+ ComputationBuilder builder(client_, TestName());
+ auto lhs = builder.ConstantR1<float>({});
+ auto rhs = builder.ConstantR1<float>({});
+ auto minimum = builder.Min(lhs, rhs);
+ ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
+}
+
+// TODO(b/28180546): Make this compile in a way that is consistent
+// among backends. See comment on MinF32s test above.
+XLA_TEST_F(ArrayElementwiseOpTest, MinF64s) {
+ ComputationBuilder builder(client_, TestName());
+#if !defined(XLA_TEST_BACKEND_CPU)
+ auto lhs = builder.ConstantR1<double>({1.0, 1.0, 2.25});
+ auto rhs = builder.ConstantR1<double>({2.0, -5.0, 1.0});
+#else
+ auto lhs = builder.ConstantR1<double>({1.0, 1.0, 2.25, NAN, 6.0});
+ auto rhs = builder.ConstantR1<double>({2.0, -5.0, 1.0, 10.0, NAN});
+#endif
+ auto minimum = builder.Min(lhs, rhs);
+
+ ComputeAndCompareR1<double>(&builder,
+#if !defined(XLA_TEST_BACKEND_CPU)
+ {1.0, -5.0, 1.0},
+#else
+ {1.0, -5.0, 1.0, 10.0, 6.0},
+#endif
+ {}, error_spec_);
+}
+
+// TODO(b/28180546): Make this compile in a way that is consistent
+// among backends. See comment on MinF32s test above.
+TEST_F(ArrayElementwiseOpTest, MaxF32s) {
+ ComputationBuilder builder(client_, TestName());
+#if !defined(XLA_TEST_BACKEND_CPU)
+ auto lhs = builder.ConstantR1<float>({1.0f, 1.0f, 2.25f});
+ auto rhs = builder.ConstantR1<float>({2.0f, -5.0f, 1.0f});
+#else
+ auto lhs = builder.ConstantR1<float>({1.0f, 1.0f, 2.25f, NAN, 6.0f});
+ auto rhs = builder.ConstantR1<float>({2.0f, -5.0f, 1.0f, 10.0f, NAN});
+#endif
+ auto maximum = builder.Max(lhs, rhs);
+
+ ComputeAndCompareR1<float>(&builder,
+#if !defined(XLA_TEST_BACKEND_CPU)
+ {2.0f, 1.0f, 2.25f},
+#else
+ {2.0f, 1.0f, 2.25f, 10.0f, 6.0f},
+#endif
+ {}, error_spec_);
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, MaxZeroElementF32s) {
+ ComputationBuilder builder(client_, TestName());
+ auto lhs = builder.ConstantR1<float>({});
+ auto rhs = builder.ConstantR1<float>({});
+  auto maximum = builder.Max(lhs, rhs);
+ ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
+}
+
+// TODO(b/28180546): Make this compile in a way that is consistent
+// among backends. See comment on MinF32s test above.
+XLA_TEST_F(ArrayElementwiseOpTest, MaxF64s) {
+ ComputationBuilder builder(client_, TestName());
+#if !defined(XLA_TEST_BACKEND_CPU)
+ auto lhs = builder.ConstantR1<double>({1.0, 1.0, 2.25});
+ auto rhs = builder.ConstantR1<double>({2.0, -5.0, 1.0});
+#else
+ auto lhs = builder.ConstantR1<double>({1.0, 1.0, 2.25, NAN, 6.0});
+ auto rhs = builder.ConstantR1<double>({2.0, -5.0, 1.0, 10.0, NAN});
+#endif
+ auto maximum = builder.Max(lhs, rhs);
+
+ ComputeAndCompareR1<double>(&builder,
+#if !defined(XLA_TEST_BACKEND_CPU)
+ {2.0, 1.0, 2.25},
+#else
+ {2.0, 1.0, 2.25, 10.0, 6.0},
+#endif
+ {}, error_spec_);
+}
+
+TEST_F(ArrayElementwiseOpTest, MaxS32s) {
+ const int32 min = std::numeric_limits<int32>::min();
+ const int32 max = std::numeric_limits<int32>::max();
+ ComputationBuilder builder(client_, TestName());
+ auto x = builder.ConstantR1<int32>(
+ {min, min, min, -1, -1, 0, 0, 0, 1, 1, max, max, max});
+ auto y = builder.ConstantR1<int32>(
+ {min, max, 0, -10, 0, -1, 0, 1, 0, 10, 0, max, min});
+ builder.Max(x, y);
+
+ std::vector<int32> expected = {min, max, 0, -1, 0, 0, 0,
+ 1, 1, 10, max, max, max};
+ ComputeAndCompareR1<int32>(&builder, expected, {});
+}
+
+TEST_F(ArrayElementwiseOpTest, MinS32s) {
+ const int32 min = std::numeric_limits<int32>::min();
+ const int32 max = std::numeric_limits<int32>::max();
+ ComputationBuilder builder(client_, TestName());
+ auto x = builder.ConstantR1<int32>(
+ {min, min, min, -1, -1, 0, 0, 0, 1, 1, max, max, max});
+ auto y = builder.ConstantR1<int32>(
+ {min, max, 0, -10, 0, -1, 0, 1, 0, 10, 0, max, min});
+ builder.Min(x, y);
+
+ std::vector<int32> expected = {min, min, min, -10, -1, -1, 0,
+ 0, 0, 1, 0, max, min};
+ ComputeAndCompareR1<int32>(&builder, expected, {});
+}
+
+TEST_F(ArrayElementwiseOpTest, MaxU32s) {
+ const uint32 max = std::numeric_limits<uint32>::max();
+ ComputationBuilder builder(client_, TestName());
+ auto x = builder.ConstantR1<uint32>({0, 0, 1, 1, 1, max, max, max});
+ auto y = builder.ConstantR1<uint32>({0, 1, 0, 1, 10, 0, 234234, max});
+ builder.Max(x, y);
+
+ std::vector<uint32> expected = {0, 1, 1, 1, 10, max, max, max};
+ ComputeAndCompareR1<uint32>(&builder, expected, {});
+}
+
+TEST_F(ArrayElementwiseOpTest, MinU32s) {
+ const uint32 max = std::numeric_limits<uint32>::max();
+ ComputationBuilder builder(client_, TestName());
+ auto x = builder.ConstantR1<uint32>({0, 0, 1, 1, 1, max, max, max});
+ auto y = builder.ConstantR1<uint32>({0, 1, 0, 1, 10, 0, 234234, max});
+ builder.Min(x, y);
+
+ std::vector<uint32> expected = {0, 0, 0, 1, 1, 0, 234234, max};
+ ComputeAndCompareR1<uint32>(&builder, expected, {});
+}
+
+TEST_F(ArrayElementwiseOpTest, MaxTenF32s) {
+ ComputationBuilder builder(client_, TestName());
+ auto x = builder.ConstantR1<float>(
+ {-0.0, 1.0, 2.0, -3.0, -4.0, 5.0, 6.0, -7.0, -8.0, 9.0});
+ auto y = builder.ConstantR1<float>(
+ {-0.0, -1.0, -2.0, 3.0, 4.0, -5.0, -6.0, 7.0, 8.0, -9.0});
+ builder.Max(x, y);
+
+ std::vector<float> expected = {-0.0, 1.0, 2.0, 3.0, 4.0,
+ 5.0, 6.0, 7.0, 8.0, 9.0};
+ ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, MaxR1S1AndR1S0F32s) {
+ ComputationBuilder builder(client_, TestName());
+ auto u = builder.ConstantR1<float>({3.5});
+ auto v = builder.ConstantR1<float>({});
+ builder.Max(u, v);
+
+ ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, MaxR1S0AndR2S0x2F32s) {
+ for (int broadcast_dim : {0, 1}) {
+ ComputationBuilder builder(client_, TestName());
+ auto u = builder.ConstantR1<float>({3.5});
+ auto v = builder.ConstantR2FromArray2D<float>(Array2D<float>(0, 2));
+ builder.Max(u, v, /*broadcast_dimensions=*/{broadcast_dim});
+
+ ComputeAndCompareR2<float>(&builder, Array2D<float>(0, 2), {}, error_spec_);
+ }
+}
+
+TEST_F(ArrayElementwiseOpTest, Max1DAnd2DF32s) {
+ ComputationBuilder builder(client_, TestName());
+ auto v = builder.ConstantR1<float>({2.0f, 3.0f, 4.0f});
+ auto m =
+ builder.ConstantR2<float>({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
+ builder.Max(v, m, /*broadcast_dimensions=*/{1});
+
+ Array2D<float> expected({{2.0f, 3.14f, 4.0f}, {2.25f, 3.0f, 4.0f}});
+ ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, Max1DAnd2DZeroElementF32s) {
+ ComputationBuilder builder(client_, TestName());
+ auto v = builder.ConstantR1<float>({});
+ auto m = builder.ConstantR2<float>({{}, {}});
+ builder.Max(v, m, /*broadcast_dimensions=*/{1});
+
+ Array2D<float> expected({{}, {}});
+ ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, Max3DAndScalarS32s) {
+ ComputationBuilder builder(client_, TestName());
+ auto scalar = builder.ConstantR0<int32>(2);
+ Array3D<int32> a_3d({{{3, 9, -1}, {2, -10, 3}}, {{-2, 2, 8}, {12, 10, 4}}});
+ auto array = builder.ConstantR3FromArray3D<int32>(a_3d);
+ builder.Max(array, scalar, /*broadcast_dimensions=*/{});
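+  // Broadcasting a scalar takes an empty broadcast_dimensions list; the
+  // scalar is compared against every element of the array.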
+
+ Array3D<int32> expected({{{3, 9, 2}, {2, 2, 3}}, {{2, 2, 8}, {12, 10, 4}}});
+ ComputeAndCompareR3<int32>(&builder, expected, {});
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, Max3DAndScalarZeroElementS32s) {
+ ComputationBuilder builder(client_, TestName());
+ auto scalar = builder.ConstantR0<int32>(2);
+ Array3D<int32> a_3d(2, 0, 3);
+ auto array = builder.ConstantR3FromArray3D<int32>(a_3d);
+ builder.Max(array, scalar, /*broadcast_dimensions=*/{});
+
+ Array3D<int32> expected(2, 0, 3);
+ ComputeAndCompareR3<int32>(&builder, expected, {});
+}
+
+TEST_F(ArrayElementwiseOpTest, Min2DTo1DF32s) {
+ ComputationBuilder builder(client_, TestName());
+ auto m =
+ builder.ConstantR2<float>({{-10.4f, 64.0f, 6.0f}, {0.1f, 32.0f, 16.1f}});
+ auto v = builder.ConstantR1<float>({-10.2f, 16.4f});
+ builder.Min(m, v, /*broadcast_dimensions=*/{0});
+
+ Array2D<float> expected({{-10.4f, -10.2f, -10.2f}, {0.1f, 16.4f, 16.1f}});
+ ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo1DZeroElementF32s) {
+ ComputationBuilder builder(client_, TestName());
+ auto m = builder.ConstantR2<float>({{}, {}});
+ auto v = builder.ConstantR1<float>({-10.2f, 16.4f});
+ builder.Min(m, v, /*broadcast_dimensions=*/{0});
+
+ Array2D<float> expected({{}, {}});
+ ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo4DF32s) {
+ ComputationBuilder builder(client_, TestName());
+ auto array2d =
+ builder.ConstantR2<float>({{-12.2f, 64.3f, 6.1f}, {0.0f, 32.2f, 2.5f}});
+ auto array4d = builder.ConstantR4FromArray4D<float>(
+ {{{{-12.1f, 32.3f, 6.2f}}, {{0.0f, 32.5f, 3.0f}}},
+ {{{-2.5f, 64.29f, 6.5f}}, {{-0.01f, 32.25f, 2.6f}}}});
+ builder.Min(array2d, array4d, /*broadcast_dimensions=*/{1, 3});
+
+ Array4D<float> expected(
+ {{{{-12.2f, 32.3f, 6.1f}}, {{0.0f, 32.2f, 2.5f}}},
+ {{{-12.2f, 64.29f, 6.1f}}, {{-0.01f, 32.2f, 2.5f}}}});
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo4DZeroElementF32s) {
+ ComputationBuilder builder(client_, TestName());
+ auto array2d =
+ builder.ConstantR2<float>({{-12.2f, 64.3f, 6.1f}, {0.0f, 32.2f, 2.5f}});
+ Array4D<float> arg(2, 2, 0, 3);
+ auto array4d = builder.ConstantR4FromArray4D<float>(arg);
+ builder.Min(array2d, array4d, /*broadcast_dimensions=*/{1, 3});
+
+ Array4D<float> expected(2, 2, 0, 3);
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, MinTenS32s) {
+ ComputationBuilder builder(client_, TestName());
+ auto x = builder.ConstantR1<int32>({0, 1, 2, 3, 4, 5, 6, 7, 8, 9});
+ auto y = builder.ConstantR1<int32>({9, 8, 7, 6, 5, 4, 3, 2, 1, 0});
+ builder.Min(x, y);
+
+ std::vector<int32> expected = {0, 1, 2, 3, 4, 4, 3, 2, 1, 0};
+ ComputeAndCompareR1<int32>(&builder, expected, {});
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, MaxTenS32s) {
+ ComputationBuilder builder(client_, TestName());
+ auto x = builder.ConstantR1<int32>({0, 1, 2, 3, 4, 5, 6, 7, 8, 9});
+ auto y = builder.ConstantR1<int32>({9, 8, 7, 6, 5, 4, 3, 2, 1, 0});
+ builder.Max(x, y);
+
+ std::vector<int32> expected = {9, 8, 7, 6, 5, 5, 6, 7, 8, 9};
+ ComputeAndCompareR1<int32>(&builder, expected, {});
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, RemTwoConstantS32s) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<int32>({-3, 26, 2, -1, 1});
+ auto b = builder.ConstantR1<int32>({10, 5, 1, 10, -10});
+  auto rem = builder.Rem(a, b);
+
+ ComputeAndCompareR1<int32>(&builder, {-3, 1, 0, -1, 1}, {});
+}
+
+TEST_F(ArrayElementwiseOpTest, NonNanClampF32) {
+ ComputationBuilder builder(client_, TestName());
+ auto minimum = builder.ConstantR1<float>({1.0f, -6.5f, 1.0f, 2.25f, 0.0f});
+ auto argument = builder.ConstantR1<float>({2.0f, 10.0f, -5.0f, 1.0f, 10.0f});
+  auto maximum = builder.ConstantR1<float>({3.0f, 0.5f, 25.5f, 5.0f, 123.0f});
+ auto clamp = builder.Clamp(minimum, argument, maximum);
+
+ ComputeAndCompareR1<float>(&builder, {2.0f, 0.5f, 1.0f, 2.25f, 10.0f}, {},
+ error_spec_);
+}
+
+TEST_F(ArrayElementwiseOpTest, ClampF32Scalar) {
+ ComputationBuilder builder(client_, TestName());
+ auto minimum = builder.ConstantR0<float>(0.0f);
+ auto argument = builder.ConstantR1<float>({2.0f, 10.0f, -5.0f, 1.0f, 4.0f});
+ auto maximum = builder.ConstantR0<float>(5.0f);
+ auto clamp = builder.Clamp(minimum, argument, maximum);
+
+ ComputeAndCompareR1<float>(&builder, {2.0f, 5.0f, 0.0f, 1.0f, 4.0f}, {},
+ error_spec_);
+}
+
+TEST_F(ArrayElementwiseOpTest, ClampF32ScalarVector) {
+ ComputationBuilder builder(client_, TestName());
+ auto min_scalar = builder.ConstantR0<float>(0.0f);
+ auto min_vector = builder.ConstantR1<float>({1.0f, -6.5f, 1.0f, 2.25f, 0.0f});
+ auto arg_vector = builder.ConstantR1<float>({2.0f, 10.0f, -5.0f, 1.0f, 4.0f});
+  auto arg_vector2 =
+      builder.ConstantR1<float>({2.0f, 10.0f, -5.0f, 1.0f, 4.0f});
+ auto max_scalar = builder.ConstantR0<float>(3.0f);
+  auto max_vector =
+      builder.ConstantR1<float>({3.0f, 0.5f, 25.5f, 5.0f, 123.0f});
+ // Perform clamp with broadcasted scalar and vector.
+ auto clamp = builder.Add(
+ builder.Add(builder.Clamp(min_vector, arg_vector, max_scalar),
+ builder.Clamp(min_scalar, arg_vector, max_vector)),
+      builder.Add(builder.Clamp(min_vector, arg_vector2, max_vector),
+                  builder.Clamp(min_scalar, arg_vector2, max_vector)));
+
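+  // Each output element is the sum of four clamp results; for element 0 all
+  // four clamps yield 2.0f, giving a total of 8.0f.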
+ ComputeAndCompareR1<float>(&builder, {8.0f, 4.5f, 2.0f, 6.5f, 15.0f}, {},
+ error_spec_);
+}
+
+TEST_F(ArrayElementwiseOpTest, AddTwoParametersF32s) {
+ ComputationBuilder builder(client_, TestName());
+
+ std::unique_ptr<Literal> param0_literal =
+ LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 5.5f});
+ std::unique_ptr<GlobalData> param0_data =
+ client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+
+ std::unique_ptr<Literal> param1_literal =
+ LiteralUtil::CreateR1<float>({7.2f, 2.3f, 3.4f, 5.6f});
+ std::unique_ptr<GlobalData> param1_data =
+ client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+
+ auto p0 = builder.Parameter(0, param0_literal->shape(), "param0");
+ auto p1 = builder.Parameter(1, param1_literal->shape(), "param1");
+ auto add = builder.Add(p0, p1);
+
+ ComputeAndCompareR1<float>(&builder, {8.3f, 4.5f, 6.7f, 11.1f},
+ {param0_data.get(), param1_data.get()},
+ error_spec_);
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersZeroElementF32s) {
+ ComputationBuilder builder(client_, TestName());
+
+ std::unique_ptr<Literal> param0_literal =
+ LiteralUtil::CreateR3FromArray3D<float>(Array3D<float>(0, 7, 0));
+ std::unique_ptr<GlobalData> param0_data =
+ client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+
+ std::unique_ptr<Literal> param1_literal =
+ LiteralUtil::CreateR3FromArray3D<float>(Array3D<float>(0, 7, 0));
+ std::unique_ptr<GlobalData> param1_data =
+ client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+
+ auto p0 = builder.Parameter(0, param0_literal->shape(), "param0");
+ auto p1 = builder.Parameter(1, param1_literal->shape(), "param1");
+ auto add = builder.Add(p0, p1);
+
+ Array3D<float> expected(0, 7, 0);
+ ComputeAndCompareR3<float>(
+ &builder, expected, {param0_data.get(), param1_data.get()}, error_spec_);
+}
+
+TEST_F(ArrayElementwiseOpTest, AddParameterToConstantF32s) {
+ ComputationBuilder builder(client_, TestName());
+
+ std::unique_ptr<Literal> param0_literal =
+ LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 5.5f});
+ std::unique_ptr<GlobalData> param0_data =
+ client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+
+ auto a = builder.ConstantR1<float>({1.1f, 2.2f, 3.3f, 4.4f});
+ auto p = builder.Parameter(0, param0_literal->shape(), "param0");
+ auto add = builder.Add(a, p);
+
+ ComputeAndCompareR1<float>(&builder, {2.2f, 4.4f, 6.6f, 9.9f},
+ {param0_data.get()}, error_spec_);
+}
+
+TEST_F(ArrayElementwiseOpTest, TanhF32s) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<float>({-2.5f, 3.14f, 2.25f});
+ auto result = builder.Tanh(a);
+
+  ComputeAndCompareR1<float>(&builder, {-0.986614f, 0.996260f, 0.978026f}, {},
+ error_spec_);
+}
+
+TEST_F(ArrayElementwiseOpTest, AddChainFoldLeft) {
+ // a ------ (add) --------- (add)
+ // / /
+ // b -----/ /
+ // c---------------------/
+ ComputationBuilder builder(client_, TestName());
+
+ auto a = builder.ConstantR1<float>({1.1f, 2.2f, 3.3f, 4.4f});
+ auto b = builder.ConstantR1<float>({2.1f, 3.2f, 4.3f, 5.4f});
+ auto c = builder.ConstantR1<float>({-3.3f, -15.5f, -7.7f, -29.9f});
+
+ auto add = builder.Add(a, b);
+ auto add2 = builder.Add(add, c);
+
+ ComputeAndCompareR1<float>(&builder, {-0.1f, -10.1f, -0.1f, -20.1f}, {},
+ error_spec_);
+}
+
+TEST_F(ArrayElementwiseOpTest, AddChainFoldRight) {
+ // b ------ (add) --------- (add)
+ // / /
+ // c -----/ /
+ // a---------------------/
+ ComputationBuilder builder(client_, TestName());
+
+ auto a = builder.ConstantR1<float>({91.1f, 2.2f, 3.3f, 4.4f});
+ auto b = builder.ConstantR1<float>({2.1f, 3.2f, 4.3f, 5.4f});
+ auto c = builder.ConstantR1<float>({-3.3f, -15.5f, -7.7f, -29.9f});
+
+ auto add = builder.Add(b, c);
+ auto add2 = builder.Add(a, add);
+
+ ComputeAndCompareR1<float>(&builder, {89.9f, -10.1f, -0.1f, -20.1f}, {},
+ error_spec_);
+}
+
+TEST_F(ArrayElementwiseOpTest, AddWithNeg) {
+ // a ----- (neg) ----- (add)
+ // /
+ // b ----- (neg) ----/
+ ComputationBuilder builder(client_, TestName());
+
+ auto a = builder.ConstantR1<float>({91.1f, 2.2f, 3.3f, 4.4f});
+ auto b = builder.ConstantR1<float>({2.1f, 3.2f, 4.3f, 5.4f});
+
+ auto neg_a = builder.Neg(a);
+ auto neg_b = builder.Neg(b);
+ auto result = builder.Add(neg_a, neg_b);
+
+ ComputeAndCompareR1<float>(&builder, {-93.2f, -5.4f, -7.6f, -9.8f}, {},
+ error_spec_);
+}
+
+TEST_F(ArrayElementwiseOpTest, AddChainTwoSide) {
+ // a ------ (add) ------------\
+ // / \
+ // b -----/ (add)
+ // /
+ // c ------ (add) ------------/
+ // /
+ // d -----/
+ ComputationBuilder builder(client_, TestName());
+
+ auto a = builder.ConstantR1<float>({91.1f, 2.2f, 3.3f, 4.4f});
+ auto b = builder.ConstantR1<float>({2.1f, 3.2f, 4.3f, 5.4f});
+ auto c = builder.ConstantR1<float>({-3.3f, -15.5f, -7.7f, -29.9f});
+ auto d = builder.ConstantR1<float>({-19.0f, 10.0f, -40.0f, 20.2f});
+
+ auto add_ab = builder.Add(a, b);
+ auto add_cd = builder.Add(c, d);
+ auto add_all = builder.Add(add_ab, add_cd);
+
+ ComputeAndCompareR1<float>(&builder, {70.9f, -0.1f, -40.1f, 0.1f}, {},
+ error_spec_);
+}
+
+TEST_F(ArrayElementwiseOpTest, 2DBinaryOpF32s) {
+ ComputationBuilder builder(client_, TestName());
+ auto a =
+ builder.ConstantR2<float>({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
+ auto b =
+      builder.ConstantR2<float>({{-1.5f, 8.14f, 42.0f}, {-1.0f, -4.0f, 5.55f}});
+ auto add = builder.Add(a, b);
+
+ Array2D<float> expected_array(
+ {{-4.0f, 11.28f, 43.0f}, {1.25f, -14.0f, 8.88f}});
+ ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, ScalarPlus2DF32) {
+ // Add a scalar + matrix.
+ ComputationBuilder builder(client_, TestName());
+ auto a =
+ builder.ConstantR2<float>({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
+ auto scalar = builder.ConstantR0<float>(3.0f);
+ auto add = builder.Add(scalar, a);
+
+ Array2D<float> expected_array({{0.5f, 6.14f, 4.0f}, {5.25f, -7.0f, 6.33f}});
+ ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
+}
+
+TEST_F(ArrayElementwiseOpTest, 2DPlusScalarF32) {
+ // Add a matrix + scalar.
+ ComputationBuilder builder(client_, TestName());
+ auto a =
+ builder.ConstantR2<float>({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
+ auto scalar = builder.ConstantR0<float>(3.0f);
+ auto add = builder.Add(a, scalar);
+
+ Array2D<float> expected_array({{0.5f, 6.14f, 4.0f}, {5.25f, -7.0f, 6.33f}});
+ ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo2DF32) {
+  // Test simple broadcasting of an R1F32 over an R2F32. The vector's size
+  // matches only dimension 1 of the matrix (the broadcast dimension below).
+ ComputationBuilder builder(client_, TestName());
+ auto v = builder.ConstantR1<float>({20.0f, 40.0f, 60.0f});
+ // clang-format off
+ auto m = builder.ConstantR2<float>({
+ {-2.5f, 3.14f, 1.0f},
+ {2.25f, -10.0f, 3.33f}});
+ // clang-format on
+ auto add = builder.Add(v, m, /*broadcast_dimensions=*/{1});
+ Array2D<float> expected_array(
+ {{17.5f, 43.14f, 61.0f}, {22.25f, 30.0f, 63.33f}});
+ ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Eq) {
+ // Test broadcasting in Eq comparison.
+ ComputationBuilder builder(client_, TestName());
+ auto v = builder.ConstantR1<int32>({42, 73});
+ auto m = builder.ConstantR2<int32>({{42, 73}, {42, 52}});
+
+ // This test exercises both possible broadcast dimensions for a vector/matrix
+ // comparison.
+ auto cmp_dim_0 = builder.Eq(v, m, /*broadcast_dimensions=*/{1});
+ auto cmp_dim_1 = builder.Eq(v, m, /*broadcast_dimensions=*/{0});
+ auto result = builder.Tuple({cmp_dim_0, cmp_dim_1});
+
+ auto expected = LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR2<bool>({{true, true}, {true, false}}).get(),
+ LiteralUtil::CreateR2<bool>({{true, false}, {false, false}}).get()});
+ ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Ne) {
+ // Test broadcasting in Ne comparison.
+ ComputationBuilder builder(client_, TestName());
+ auto v = builder.ConstantR1<int32>({42, 73});
+ auto m = builder.ConstantR2<int32>({{42, 73}, {42, 52}});
+ auto cmp = builder.Ne(v, m, /*broadcast_dimensions=*/{1});
+
+ const string expected = R"(pred[2,2] {
+ { 00 },
+ { 01 },
+})";
+ EXPECT_EQ(expected, ExecuteToString(&builder, {}));
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Ge) {
+ // Test broadcasting in Ge comparison.
+ ComputationBuilder builder(client_, TestName());
+ auto v = builder.ConstantR1<int32>({1, 2, 3, 4});
+ auto m = builder.ConstantR2<int32>({{1, 0, 5, 6}, {42, 52, 10, 4}});
+ auto cmp = builder.Ge(v, m, /*broadcast_dimensions=*/{1});
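+  // Row 0 compares v to {1, 0, 5, 6}: {1>=1, 2>=0, 3>=5, 4>=6} = 1100.
+  // Row 1 compares v to {42, 52, 10, 4}: only 4 >= 4 holds, giving 0001.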
+
+ const string expected = R"(pred[2,4] {
+ { 1100 },
+ { 0001 },
+})";
+ EXPECT_EQ(expected, ExecuteToString(&builder, {}));
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Gt) {
+ // Test broadcasting in Gt comparison.
+ ComputationBuilder builder(client_, TestName());
+ auto v = builder.ConstantR1<int32>({1, 2, 3, 4});
+ auto m = builder.ConstantR2<int32>({{1, 0, 5, 6}, {42, 52, 10, 4}});
+ auto cmp = builder.Gt(v, m, /*broadcast_dimensions=*/{1});
+
+ const string expected = R"(pred[2,4] {
+ { 0100 },
+ { 0000 },
+})";
+ EXPECT_EQ(expected, ExecuteToString(&builder, {}));
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Le) {
+ // Test broadcasting in Le comparison.
+ ComputationBuilder builder(client_, TestName());
+ auto v = builder.ConstantR1<int32>({1, 2, 3, 4});
+ auto m = builder.ConstantR2<int32>({{1, 0, 5, 6}, {42, 52, 10, 4}});
+ auto cmp = builder.Le(v, m, /*broadcast_dimensions=*/{1});
+
+ const string expected = R"(pred[2,4] {
+ { 1011 },
+ { 1111 },
+})";
+ EXPECT_EQ(expected, ExecuteToString(&builder, {}));
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Lt) {
+ // Test broadcasting in Lt comparison.
+ ComputationBuilder builder(client_, TestName());
+ auto v = builder.ConstantR1<int32>({1, 2, 3, 4});
+ auto m = builder.ConstantR2<int32>({{1, 0, 5, 6}, {42, 52, 10, 4}});
+ auto cmp = builder.Lt(v, m, /*broadcast_dimensions=*/{1});
+
+ const string expected = R"(pred[2,4] {
+ { 0011 },
+ { 1110 },
+})";
+ EXPECT_EQ(expected, ExecuteToString(&builder, {}));
+}
+
+TEST_F(ArrayElementwiseOpTest, Mul2Dby1DF32) {
+  // Test simple broadcasting of an R1F32 over an R2F32 when the order of the
+  // binary op's arguments is reversed.
+ ComputationBuilder builder(client_, TestName());
+ auto m = builder.ConstantR2<float>({{1.5f, 2.5f, 3.5f}, {4.5f, 5.5f, 6.5f}});
+ auto v = builder.ConstantR1<float>({2.0f, 4.0f, 6.0f});
+  auto mul = builder.Mul(m, v, /*broadcast_dimensions=*/{1});
+ Array2D<float> expected_array({{3.0f, 10.0f, 21.0f}, {9.0f, 22.0f, 39.0f}});
+ ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
+}
+
+TEST_F(ArrayElementwiseOpTest, Add2DTo2DWithDegenerateDim1) {
+ // Tests broadcasting for arrays with degenerate (size == 1) dimensions.
+ ComputationBuilder builder(client_, TestName());
+  // m has shape {2, 3} and md has shape {1, 3}.
+  // The result has shape {2, 3}, where md is broadcast over the rows of m.
+ auto m =
+ builder.ConstantR2<float>({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
+ auto md = builder.ConstantR2<float>({{10.0f, 20.0f, 30.0f}});
+ auto add = builder.Add(m, md);
+ Array2D<float> expected_array(
+ {{7.5f, 23.14f, 31.0f}, {12.25f, 10.0f, 33.33f}});
+ ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, Add2DTo2DWithDegenerateDim0) {
+ // Tests broadcasting for arrays with degenerate (size == 1) dimensions.
+ ComputationBuilder builder(client_, TestName());
+  // m has shape {2, 3} and md has shape {2, 1}.
+  // The result has shape {2, 3}, where md is broadcast over the columns of m.
+ auto m =
+ builder.ConstantR2<float>({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
+ auto md = builder.ConstantR2<float>({{10.0f}, {20.0f}});
+ auto add = builder.Add(m, md);
+ Array2D<float> expected_array(
+ {{7.5f, 13.14f, 11.0f}, {22.25f, 10.0f, 23.33f}});
+ ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, Add2DsWithDegenerateDimsOuterProduct) {
+  // Tests broadcasting between two arrays that each have a degenerate
+  // (size == 1) dimension. This kind of broadcasting effectively creates an
+  // "outer product" operation.
+ // This is taken from the Numpy docs example at:
+ // http://docs.scipy.org/doc/numpy-1.10.1/user/basics.broadcasting.html
+ ComputationBuilder builder(client_, TestName());
+  // a has shape {4, 1} and b has shape {1, 3}.
+  // The result has shape {4, 3}.
+ auto a = builder.ConstantR2<float>({{0.0f}, {10.0f}, {20.0f}, {30.0f}});
+ auto b = builder.ConstantR2<float>({{1.0f, 2.0f, 3.0f}});
+ auto add = builder.Add(a, b);
+ Array2D<float> expected_array({{1.0f, 2.0f, 3.0f},
+ {11.0f, 12.0f, 13.0f},
+ {21.0f, 22.0f, 23.0f},
+ {31.0f, 32.0f, 33.0f}});
+ ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo2DF32TwoWaysOver1) {
+  // Add together a (2,2) array and a (2) array, using dimension 1 for
+ // broadcasting (though there are two ways to broadcast these shapes).
+ ComputationBuilder builder(client_, TestName());
+ auto v = builder.ConstantR1<float>({20.0f, 40.0f});
+ auto m = builder.ConstantR2<float>({{10.0f, 50.0f}, {77.0f, 88.0f}});
+ auto add = builder.Add(v, m, /*broadcast_dimensions=*/{1});
+ Array2D<float> expected_array({{30.0f, 90.0f}, {97.0f, 128.0f}});
+ ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
+}
+
+TEST_F(ArrayElementwiseOpTest, Add1DTo2DF32TwoWaysOver0) {
+  // Add together a (2,2) array and a (2) array, using dimension 0 for
+ // broadcasting (though there are two ways to broadcast these shapes).
+ ComputationBuilder builder(client_, TestName());
+ auto v = builder.ConstantR1<float>({20.0f, 40.0f});
+ auto m = builder.ConstantR2<float>({{10.0f, 50.0f}, {77.0f, 88.0f}});
+ auto add = builder.Add(v, m, /*broadcast_dimensions=*/{0});
+ Array2D<float> expected_array({{30.0f, 70.0f}, {117.0f, 128.0f}});
+ ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
+}
+
+TEST_F(ArrayElementwiseOpTest, 3DBinaryOpF32s) {
+ // Binary add of two R3s together
+ ComputationBuilder builder(client_, TestName());
+ Array3D<float> a_3d({{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}},
+ {{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 12.0f}}});
+ auto a = builder.ConstantR3FromArray3D<float>(a_3d);
+
+ Array3D<float> b_3d({{{2.0f, 4.0f}, {6.0f, 8.0f}, {10.0f, 12.0f}},
+ {{14.0f, 16.0f}, {18.0f, 20.0f}, {22.0f, 24.0f}}});
+ auto b = builder.ConstantR3FromArray3D<float>(b_3d);
+ auto add = builder.Add(a, b);
+
+ Array3D<float> expected_3d(
+ {{{3.0f, 6.0f}, {9.0f, 12.0f}, {15.0f, 18.0f}},
+ {{21.0f, 24.0f}, {27.0f, 30.0f}, {33.0f, 36.0f}}});
+ ComputeAndCompareR3<float>(&builder, expected_3d, {}, error_spec_);
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo3DTwoWaysOver2) {
+  // Add together a (2, 3, 2) array with a (2) array, using dimension 2 for
+ // broadcasting (though there are two ways to broadcast these shapes).
+ ComputationBuilder builder(client_, TestName());
+ // clang-format off
+ Array3D<float> a_3d({
+ {{1.0f, 2.0f},
+ {3.0f, 4.0f},
+ {5.0f, 6.0f}},
+ {{7.0f, 8.0f},
+ {9.0f, 10.0f},
+ {11.0f, 12.0f}},
+ });
+ // clang-format on
+ auto a = builder.ConstantR3FromArray3D<float>(a_3d);
+ auto v = builder.ConstantR1<float>({10.0f, 20.0f});
+ auto add = builder.Add(a, v, /*broadcast_dimensions=*/{2});
+
+ Array3D<float> expected_3d(
+ {{{11.0f, 22.0f}, {13.0f, 24.0f}, {15.0f, 26.0f}},
+ {{17.0f, 28.0f}, {19.0f, 30.0f}, {21.0f, 32.0f}}});
+ ComputeAndCompareR3<float>(&builder, expected_3d, {}, error_spec_);
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo3DTwoWaysOver0) {
+  // Add together a (2, 3, 2) array with a (2) array, using dimension 0 for
+ // broadcasting (though there are two ways to broadcast these shapes).
+ ComputationBuilder builder(client_, TestName());
+ // clang-format off
+ Array3D<float> a_3d({
+ {{1.0f, 2.0f},
+ {3.0f, 4.0f},
+ {5.0f, 6.0f}},
+ {{7.0f, 8.0f},
+ {9.0f, 10.0f},
+ {11.0f, 12.0f}},
+ });
+ // clang-format on
+ auto a = builder.ConstantR3FromArray3D<float>(a_3d);
+ auto v = builder.ConstantR1<float>({10.0f, 20.0f});
+ auto add = builder.Add(a, v, /*broadcast_dimensions=*/{0});
+
+ // clang-format off
+ Array3D<float> expected_3d({
+ {{11.0f, 12.0f},
+ {13.0f, 14.0f},
+ {15.0f, 16.0f}},
+ {{27.0f, 28.0f},
+ {29.0f, 30.0f},
+ {31.0f, 32.0f}},
+ });
+ // clang-format on
+ ComputeAndCompareR3<float>(&builder, expected_3d, {}, error_spec_);
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, Add2DTo3D) {
+  // Add together a (2, 3, 2) array with a (2, 3) array, using dimensions
+  // {0, 1} for broadcasting.
+ ComputationBuilder builder(client_, TestName());
+ // clang-format off
+ Array3D<float> a_3d({
+ {{1.0f, 2.0f},
+ {3.0f, 4.0f},
+ {5.0f, 6.0f}},
+ {{7.0f, 8.0f},
+ {9.0f, 10.0f},
+ {11.0f, 12.0f}},
+ });
+ auto a = builder.ConstantR3FromArray3D<float>(a_3d);
+ auto m = builder.ConstantR2<float>({
+ {10.0f, 20.0f, 30.0f},
+ {40.0f, 50.0f, 60.0f},
+ });
+ auto add = builder.Add(a, m, /*broadcast_dimensions=*/{0, 1});
+
+ Array3D<float> expected_3d({
+ {{11.0f, 12.0f},
+ {23.0f, 24.0f},
+ {35.0f, 36.0f}},
+ {{47.0f, 48.0f},
+ {59.0f, 60.0f},
+ {71.0f, 72.0f}},
+ });
+ // clang-format on
+ ComputeAndCompareR3<float>(&builder, expected_3d, {}, error_spec_);
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, CompareGtR3F32sWithDegenerateDim2) {
+ // Comparison between two 3D arrays of compatible shapes:
+  // (2, 3, 2) and (1, 3, 2): expected to produce a (2, 3, 2) shape of PREDs.
+ ComputationBuilder builder(client_, TestName());
+ Array3D<float> a_3d({{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}},
+ {{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 12.0f}}});
+ auto a = builder.ConstantR3FromArray3D<float>(a_3d);
+
+ Array3D<float> b_3d({{{7.0f, 1.0f}, {3.0f, 10.0f}, {15.0f, 6.0f}}});
+ auto b = builder.ConstantR3FromArray3D<float>(b_3d);
+
+ auto compare = builder.Gt(a, b);
+
+ Array3D<int> expected_3d(
+ {{{0, 1}, {0, 0}, {0, 0}}, {{0, 1}, {1, 0}, {0, 1}}});
+ const string expected = R"(pred[2,3,2] {
+{ { 01 },
+ { 00 },
+ { 00 } },
+{ { 01 },
+ { 10 },
+ { 01 } }
+})";
+ EXPECT_EQ(expected, ExecuteToString(&builder, {}));
+}
+
+TEST_F(ArrayElementwiseOpTest, 4DBinaryOpF32s) {
+ ComputationBuilder builder(client_, TestName());
+
+ std::unique_ptr<Array4D<float>> operand_a_4d(new Array4D<float>(2, 3, 4, 5));
+ std::unique_ptr<Array4D<float>> operand_b_4d(new Array4D<float>(2, 3, 4, 5));
+ std::unique_ptr<Array4D<float>> expected_4d(new Array4D<float>(2, 3, 4, 5));
+ float value = 0.0;
+ for (int64 p = 0; p < 2; ++p) {
+ for (int64 z = 0; z < 3; ++z) {
+ for (int64 y = 0; y < 4; ++y) {
+ for (int64 x = 0; x < 5; ++x) {
+ (*operand_a_4d)(p, z, y, x) = value;
+ (*operand_b_4d)(p, z, y, x) = 2.0 * value;
+ (*expected_4d)(p, z, y, x) = 3.0 * value;
+ value += 0.1;
+ }
+ }
+ }
+ }
+
+ auto a = builder.ConstantR4FromArray4D<float>(*operand_a_4d);
+ auto b = builder.ConstantR4FromArray4D<float>(*operand_b_4d);
+ auto add = builder.Add(a, b);
+
+ ComputeAndCompareR4<float>(&builder, *expected_4d, {}, error_spec_);
+}
+
+TEST_F(ArrayElementwiseOpTest, R4PlusR1InDim1) {
+ ComputationBuilder builder(client_, TestName());
+
+ std::unique_ptr<Array4D<float>> operand_a_4d(new Array4D<float>(2, 3, 4, 5));
+ std::unique_ptr<Array4D<float>> expected_4d(new Array4D<float>(2, 3, 4, 5));
+ std::vector<float> operand_b_1d(3);
+ std::iota(operand_b_1d.begin(), operand_b_1d.end(), 1.0);
+
+ float value = 0.0;
+ for (int64 p = 0; p < 2; ++p) {
+ for (int64 z = 0; z < 3; ++z) {
+ for (int64 y = 0; y < 4; ++y) {
+ for (int64 x = 0; x < 5; ++x) {
+ (*operand_a_4d)(p, z, y, x) = value;
+ (*expected_4d)(p, z, y, x) = value + operand_b_1d[z];
+ value += 0.1;
+ }
+ }
+ }
+ }
+
+ auto a = builder.ConstantR4FromArray4D<float>(*operand_a_4d);
+ auto b = builder.ConstantR1<float>(operand_b_1d);
+ auto add = builder.Add(a, b, {1});
+
+ ComputeAndCompareR4<float>(&builder, *expected_4d, {}, error_spec_);
+}
+
+TEST_F(ArrayElementwiseOpTest, R4_16x16x2x2_Plus_R1_16) {
+ constexpr int d0 = 16;
+ constexpr int d1 = 16;
+ constexpr int d2 = 2;
+ constexpr int d3 = 2;
+ Array4D<float> r4(d0, d1, d2, d3);
+ r4.Fill(1.0);
+ std::vector<float> r1(d1);
+ std::iota(r1.begin(), r1.end(), 1.0);
+
+ ComputationBuilder builder(client_, TestName());
+ std::unique_ptr<Literal> a_literal = LiteralUtil::CreateR4FromArray4D(r4);
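+  // Give the constant a non-default minor-to-major layout ({0, 1, 2, 3}
+  // instead of the default {3, 2, 1, 0}), presumably so the broadcasting add
+  // is also exercised on a non-default operand layout.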
+ *a_literal->mutable_shape()->mutable_layout() =
+ LayoutUtil::MakeLayout({0, 1, 2, 3});
+ auto a = builder.ConstantLiteral(*a_literal);
+ auto b = builder.ConstantR1<float>(r1);
+ builder.Add(a, b, {1});
+
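+  // Build the expected values by adding r1 into r4 in place along dim 1.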
+ for (int i0 = 0; i0 < d0; ++i0) {
+ for (int i1 = 0; i1 < d1; ++i1) {
+ for (int i2 = 0; i2 < d2; ++i2) {
+ for (int i3 = 0; i3 < d3; ++i3) {
+ r4(i0, i1, i2, i3) += r1[i1];
+ }
+ }
+ }
+ }
+ ComputeAndCompareR4<float>(&builder, r4, {}, error_spec_);
+}
+
+// Show that we can't add two opaques.
+TEST_F(ArrayElementwiseOpTest, CannotAddOpaques) {
+ ComputationBuilder builder(client_, TestName());
+ auto shape = ShapeUtil::MakeOpaqueShape();
+ auto x = builder.Parameter(0, shape, "x");
+  auto sum = builder.Add(x, x);
+ StatusOr<Computation> computation_status = builder.Build();
+ ASSERT_FALSE(computation_status.ok());
+ EXPECT_MATCH(computation_status.status().ToString(),
+ testing::ContainsRegex(
+ "Expected non-opaque argument for lhs of binary operation"));
+}
+
+// Regression test for b/31927799. "slice - y" is fused and requires an
+// implicit broadcast.
+TEST_F(ArrayElementwiseOpTest, ImplicitBroadcastInFusedExpressions) {
+ ComputationBuilder builder(client_, TestName());
+ auto x_literal = LiteralUtil::CreateR1<float>({1, 2, 3});
+ auto y_literal = LiteralUtil::CreateR1<float>({4, 5});
+ auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie();
+ auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie();
+
+ auto x = builder.Parameter(0, x_literal->shape(), "x");
+ auto y = builder.Parameter(1, y_literal->shape(), "y");
+ auto slice = builder.Slice(x, {1}, {2});
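+  // The slice evaluates to {2}, which is implicitly broadcast against y.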
+ builder.Sub(slice, y);
+
+ ComputeAndCompareR1<float>(&builder, {-2, -3}, {x_data.get(), y_data.get()},
+ error_spec_);
+}
+
+INSTANTIATE_TEST_CASE_P(ArrayElementwiseOpTestParamCount,
+ ArrayElementwiseOpTestParamCount,
+ ::testing::Values(127, 128, 129, 17 * 4096));
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::legacy_flags::AppendLlvmBackendFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/axpy_simple_test.cc b/tensorflow/compiler/xla/tests/axpy_simple_test.cc
new file mode 100644
index 0000000000..adffac09e3
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/axpy_simple_test.cc
@@ -0,0 +1,90 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <vector>
+
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace {
+
+class AxpySimpleTest : public ClientLibraryTestBase {};
+
+TEST_F(AxpySimpleTest, AxTenValues) {
+ ComputationBuilder builder(client_, "ax_10");
+ auto alpha = builder.ConstantR0<float>(3.1415926535);
+ auto x = builder.ConstantR1<float>(
+ {-1.0, 1.0, 2.0, -2.0, -3.0, 3.0, 4.0, -4.0, -5.0, 5.0});
+ auto ax = builder.Mul(alpha, x);
+
+ std::vector<float> expected = {
+ -3.14159265, 3.14159265, 6.28318531, -6.28318531, -9.42477796,
+ 9.42477796, 12.56637061, -12.56637061, -15.70796327, 15.70796327};
+ ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(AxpySimpleTest, AxpyZeroValues) {
+ ComputationBuilder builder(client_, "axpy_10");
+ auto alpha = builder.ConstantR0<float>(3.1415926535);
+ auto x = builder.ConstantR1<float>({});
+ auto y = builder.ConstantR1<float>({});
+ auto ax = builder.Mul(alpha, x);
+ auto axpy = builder.Add(ax, y);
+
+ std::vector<float> expected = {};
+ ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+TEST_F(AxpySimpleTest, AxpyTenValues) {
+ ComputationBuilder builder(client_, "axpy_10");
+ auto alpha = builder.ConstantR0<float>(3.1415926535);
+ auto x = builder.ConstantR1<float>(
+ {-1.0, 1.0, 2.0, -2.0, -3.0, 3.0, 4.0, -4.0, -5.0, 5.0});
+ auto y = builder.ConstantR1<float>(
+ {5.0, -5.0, -4.0, 4.0, 3.0, -3.0, -2.0, 2.0, 1.0, -1.0});
+ auto ax = builder.Mul(alpha, x);
+ auto axpy = builder.Add(ax, y);
+
+ std::vector<float> expected = {
+ 1.85840735, -1.85840735, 2.28318531, -2.28318531, -6.42477796,
+ 6.42477796, 10.56637061, -10.56637061, -14.70796327, 14.70796327};
+ ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/bad_rng_shape_validation_test.cc b/tensorflow/compiler/xla/tests/bad_rng_shape_validation_test.cc
new file mode 100644
index 0000000000..c7b533b80f
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/bad_rng_shape_validation_test.cc
@@ -0,0 +1,87 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Tests that passing a bad shape to RNG's output parameter causes a validation
+// failure rather than causing a crash.
+
+#include <memory>
+
+#include "tensorflow/compiler/xla/client/computation.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace {
+
+class BadRngShapeValidationTest : public ClientLibraryTestBase {};
+
+TEST_F(BadRngShapeValidationTest, DefaultConstructedShapeCreatesError) {
+ ComputationBuilder builder(client_, TestName());
+ auto zero = builder.ConstantR0<float>(0.0);
+ auto one = builder.ConstantR0<float>(1.0);
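+  // A default-constructed Shape has an unset element type, so building the
+  // computation should fail validation.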
+ Shape default_constructed;
+ builder.RngUniform(zero, one, default_constructed);
+
+ StatusOr<Computation> computation = builder.Build();
+ EXPECT_FALSE(computation.ok());
+ LOG(INFO) << "status received: " << computation.status();
+ EXPECT_MATCH(computation.status().error_message(),
+ testing::HasSubstr("shape has invalid"));
+}
+
+TEST_F(BadRngShapeValidationTest, ShapeWithoutLayoutIsOk) {
+ ComputationBuilder builder(client_, TestName());
+ auto zero = builder.ConstantR0<float>(0.0);
+ auto one = builder.ConstantR0<float>(1.0);
+ Shape sans_layout;
+ sans_layout.set_element_type(F32);
+ sans_layout.add_dimensions(1);
+
+ builder.RngUniform(zero, one, sans_layout);
+
+ StatusOr<Computation> computation = builder.Build();
+ ASSERT_TRUE(computation.ok());
+ LOG(INFO) << computation.status();
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/batch_normalization_test.cc b/tensorflow/compiler/xla/tests/batch_normalization_test.cc
new file mode 100644
index 0000000000..598fd69909
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/batch_normalization_test.cc
@@ -0,0 +1,213 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <cmath>
+#include <memory>
+#include <vector>
+
+#include "tensorflow/compiler/xla/array2d.h"
+#include "tensorflow/compiler/xla/array4d.h"
+#include "tensorflow/compiler/xla/client/computation.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+class BatchNormalizationTest : public ClientLibraryTestBase {
+ protected:
+ BatchNormalizationTest() : input_array_(kSamples, kZ, kY, kX) {
+ Array2D<float> pz({
+ // z0 z1
+ {-1.0f, 4.1f}, // p0
+ {2.0f, 4.1f}, // p1
+ {5.0f, 4.4f}, // p2
+ });
+ input_array_.FillWithPZ(pz);
+ input_literal_ = *LiteralUtil::CreateR4FromArray4D(input_array_);
+ CHECK_EQ(kSamples, input_array_.planes());
+ CHECK_EQ(kZ, input_array_.depth());
+ CHECK_EQ(kY, input_array_.height());
+    CHECK_EQ(kX, input_array_.width());
+ }
+
+ static constexpr int64 kSamples = 3;
+ static constexpr int64 kX = 1;
+ static constexpr int64 kY = 1;
+ static constexpr int64 kZ = 2;
+
+ Array4D<float> input_array_;
+ Literal input_literal_;
+ const ErrorSpec error_spec_{0.001, 0.001};
+};
+
+TEST_F(BatchNormalizationTest, SubtractInZ) {
+ ComputationBuilder builder(client_, "subtract_in_z_one_sample");
+ auto x = builder.ConstantLiteral(input_literal_);
+ auto y = builder.ConstantR1<float>({3.14, 4.25});
+ builder.Sub(x, y, /*broadcast_dimensions=*/{1});
+
+ Array4D<float> expected(kSamples, kZ, kY, kX);
+ Array2D<float> pz({
+ {-1.0f - 3.14f, 4.1f - 4.25f}, // p0
+ {2.0f - 3.14f, 4.1f - 4.25f}, // p1
+ {5.0f - 3.14f, 4.4f - 4.25f}, // p2
+ });
+ expected.FillWithPZ(pz);
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+TEST_F(BatchNormalizationTest, SquareTesseractElementwise) {
+ ComputationBuilder builder(client_, "square_tesseract_elementwise");
+ auto x = builder.ConstantLiteral(input_literal_);
+ builder.SquareF32(x);
+
+ Array4D<float> expected(kSamples, kZ, kY, kX);
+ Array2D<float> expected_pz({
+ {std::pow(-1.0f, 2.0f), std::pow(4.1f, 2.0f)},
+ {std::pow(2.0f, 2.0f), std::pow(4.1f, 2.0f)},
+ {std::pow(5.0f, 2.0f), std::pow(4.4f, 2.0f)},
+ });
+ expected.FillWithPZ(expected_pz);
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+TEST_F(BatchNormalizationTest, SumToZ) {
+ ComputationBuilder builder(client_, "sum_to_z");
+ auto input_activations = builder.ConstantLiteral(input_literal_);
+ Computation add = CreateScalarAddComputation(F32, &builder);
+ // Reduce all but the Z dimension.
+ builder.Reduce(input_activations, builder.ConstantR0<float>(0.0f), add,
+ {0, 2, 3});
+
+ std::vector<float> expected = {6, 12.6};
+ ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
+}
+
+TEST_F(BatchNormalizationTest, SquareAndReduce) {
+ ComputationBuilder builder(client_, "square_and_reduce");
+ auto input_activations = builder.ConstantLiteral(input_literal_);
+ auto set_means = builder.ConstantR1<float>({2.f, 4.2f});
+ auto activation_deviations = builder.Sub(input_activations, set_means,
+ /*broadcast_dimensions=*/{1});
+ Computation add = CreateScalarAddComputation(F32, &builder);
+ auto dev_squares = builder.SquareF32(activation_deviations);
+ auto sum_of_squares = builder.Reduce(
+ dev_squares, builder.ConstantR0<float>(0.0f), add, {0, 2, 3});
+
+ std::vector<float> expected = {18, 0.06};
+ ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
+}
+
+TEST_F(BatchNormalizationTest, VarianceToStddev) {
+ ComputationBuilder builder(client_, "variance_to_stddev");
+ auto variance = builder.ConstantR1<float>({6.f, .02f});
+ auto sqrt = builder.SqrtF32(variance);
+
+ std::vector<float> expected = {2.44948974f, 0.14142136f};
+ ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
+}
+
+// Compare against a forward batch normalization example in the NN spec
+// reference.
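+// For z0 = {-1, 2, 5} the mean is 2 and the variance is 6; for z1 =
+// {4.1, 4.1, 4.4} the mean is 4.2 and the variance is 0.02. With gamma = 1 and
+// beta = 0, the expected output is (x - mean) / sqrt(variance) elementwise.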
+TEST_F(BatchNormalizationTest, SpecComparisonForward) {
+ ComputationBuilder builder(client_, "batch_normalize_per_spec");
+ auto input_activations =
+ builder.CheckShape(builder.ConstantLiteral(input_literal_),
+ ShapeUtil::MakeShape(F32, {3, 2, 1, 1}));
+ auto gamma = builder.ConstantR1<float>({1.0, 1.0});
+ auto beta = builder.ConstantR1<float>({0.0, 0.0});
+ Computation add = CreateScalarAddComputation(F32, &builder);
+ // Reduce all dimensions except dimension 1.
+  Shape two_element_vector_f32 = ShapeUtil::MakeShape(F32, {2});
+  auto sum = builder.CheckShape(
+      builder.Reduce(input_activations, builder.ConstantR0<float>(0.0f), add,
+                     /*dimensions_to_reduce=*/{0, 2, 3}),
+      two_element_vector_f32);
+ auto input_shape = builder.GetShape(input_activations).ConsumeValueOrDie();
+ auto sum_shape = builder.GetShape(sum).ConsumeValueOrDie();
+ auto count = builder.ConstantR0<float>(ShapeUtil::ElementsIn(*input_shape) /
+ ShapeUtil::ElementsIn(*sum_shape));
+ auto set_means = builder.Div(sum, count);
+
+ const float kEpsilon = 1e-9f;
+ auto epsilon = builder.ConstantR0<float>(kEpsilon);
+ auto epsilon2 = builder.ConstantR1<float>({kEpsilon, kEpsilon});
+ auto activation_deviations = builder.Sub(input_activations, set_means,
+ /*broadcast_dimensions=*/{1});
+ auto dev_squares = builder.SquareF32(activation_deviations);
+ auto sum_of_squares = builder.CheckShape(
+ builder.Reduce(dev_squares, builder.ConstantR0<float>(0.0f), add,
+ /*dimensions_to_reduce=*/{0, 2, 3}),
+      two_element_vector_f32);
+ auto variance = builder.Div(sum_of_squares, count);
+ auto standard_deviation = builder.SqrtF32(variance);
+ auto standard_deviation_above_epsilon = builder.CheckShape(
+ builder.Gt(standard_deviation, epsilon), ShapeUtil::MakeShape(PRED, {2}));
+ auto gt_eps = builder.Select(standard_deviation_above_epsilon,
+ standard_deviation, epsilon2);
+ auto normalization_factors = builder.ReciprocalF32(gt_eps);
+ auto normalized_input_activations =
+ builder.Mul(activation_deviations, normalization_factors,
+ /*broadcast_dimensions=*/{1});
+ /* auto output_activations = */ builder.Add(
+ builder.Mul(normalized_input_activations, gamma,
+ /*broadcast_dimensions=*/{1}),
+ beta, /*broadcast_dimensions=*/{1});
+
+ Array4D<float> expected(kSamples, kZ, kY, kX);
+ Array2D<float> pz({
+ {-3.f / std::sqrt(6.f), -.1f / std::sqrt(.02f)},
+ {0.f, -.1f / std::sqrt(.02f)},
+ {3.f / std::sqrt(6.f), .2f / std::sqrt(.02f)},
+ });
+ expected.FillWithPZ(pz);
+
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/binop_scaling_test.cc b/tensorflow/compiler/xla/tests/binop_scaling_test.cc
new file mode 100644
index 0000000000..e825bd435b
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/binop_scaling_test.cc
@@ -0,0 +1,157 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/array2d.h"
+#include "tensorflow/compiler/xla/array4d.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/reference_util.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace {
+
+class BinopScalingTest : public ClientLibraryTestBase {};
+
+TEST_F(BinopScalingTest, MatrixPlusPseudoMatrixRowVector_32x4) {
+ auto alhs = MakeLinspaceArray2D(0.0, 1.0, 32, 4);
+ auto arhs = MakeLinspaceArray2D(0.0, 1.0, 1, 4);
+
+ ComputationBuilder builder(client_, TestName());
+ auto lhs = builder.ConstantR2FromArray2D<float>(*alhs);
+ auto rhs = builder.ConstantR2FromArray2D<float>(*arhs);
+ builder.Add(lhs, rhs);
+
+ auto aexpected = ReferenceUtil::MapWithIndexArray2D(
+ *alhs, [&](float lhs_value, int64 row, int64 col) {
+ return lhs_value + (*arhs)(0, col);
+ });
+ ComputeAndCompareR2<float>(&builder, *aexpected, {}, ErrorSpec(0.0001));
+}
+
+TEST_F(BinopScalingTest, MatrixPlusPseudoMatrixRowVector_129x129) {
+ auto alhs = MakeLinspaceArray2D(0.0, 1.0, 129, 129);
+ auto arhs = MakeLinspaceArray2D(0.0, 1.0, 1, 129);
+
+ ComputationBuilder builder(client_, TestName());
+ auto lhs = builder.ConstantR2FromArray2D<float>(*alhs);
+ auto rhs = builder.ConstantR2FromArray2D<float>(*arhs);
+ builder.Add(lhs, rhs);
+
+ auto aexpected = ReferenceUtil::MapWithIndexArray2D(
+ *alhs, [&](float lhs_value, int64 row, int64 col) {
+ return lhs_value + (*arhs)(0, col);
+ });
+ ComputeAndCompareR2<float>(&builder, *aexpected, {}, ErrorSpec(0.0001));
+}
+
+TEST_F(BinopScalingTest, MatrixPlusPseudoMatrixColVector_9x5) {
+ auto alhs = MakeLinspaceArray2D(0.0, 1.0, 9, 5);
+ auto arhs = MakeLinspaceArray2D(0.0, 1.0, 9, 1);
+
+ ComputationBuilder builder(client_, TestName());
+ auto lhs = builder.ConstantR2FromArray2D<float>(*alhs);
+ auto rhs = builder.ConstantR2FromArray2D<float>(*arhs);
+ builder.Add(lhs, rhs);
+
+ auto aexpected = ReferenceUtil::MapWithIndexArray2D(
+ *alhs, [&](float lhs_value, int64 row, int64 col) {
+ return lhs_value + (*arhs)(row, 0);
+ });
+ ComputeAndCompareR2<float>(&builder, *aexpected, {}, ErrorSpec(0.0001));
+}
+
+TEST_F(BinopScalingTest, MatrixPlusPseudoMatrixColVector_129x257) {
+ auto alhs = MakeLinspaceArray2D(0.0, 1.0, 129, 257);
+ auto arhs = MakeLinspaceArray2D(0.0, 1.0, 129, 1);
+
+ ComputationBuilder builder(client_, TestName());
+ auto lhs = builder.ConstantR2FromArray2D<float>(*alhs);
+ auto rhs = builder.ConstantR2FromArray2D<float>(*arhs);
+ builder.Add(lhs, rhs);
+
+ auto aexpected = ReferenceUtil::MapWithIndexArray2D(
+ *alhs, [&](float lhs_value, int64 row, int64 col) {
+ return lhs_value + (*arhs)(row, 0);
+ });
+ ComputeAndCompareR2<float>(&builder, *aexpected, {}, ErrorSpec(0.0001));
+}
+
+TEST_F(BinopScalingTest, R0PlusR2F32) {
+ ComputationBuilder builder(client_, TestName());
+ auto lhs = builder.ConstantR0<float>(42.0);
+ auto rhs = builder.ConstantR2<float>({
+ {1.0, 2.0}, {3.0, 4.0},
+ });
+ builder.Add(lhs, rhs);
+
+ Array2D<float> expected(2, 2);
+ expected(0, 0) = 42.0 + 1.0;
+ expected(0, 1) = 42.0 + 2.0;
+ expected(1, 0) = 42.0 + 3.0;
+ expected(1, 1) = 42.0 + 4.0;
+ ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+TEST_F(BinopScalingTest, R4PlusR0S32) {
+ ComputationBuilder builder(client_, TestName());
+ // clang-format off
+ Array4D<int> lhs_array({
+ {{{1, 2},
+ {3, 4},
+ {5, 6}}},
+ {{{7, 8},
+ {9, 10},
+ {11, 12}}},
+ });
+ Array4D<int> expected({
+ {{{43, 44},
+ {45, 46},
+ {47, 48}}},
+ {{{49, 50},
+ {51, 52},
+ {53, 54}}},
+ });
+ // clang-format on
+
+ auto lhs = builder.ConstantR4FromArray4D(lhs_array);
+ auto rhs = builder.ConstantR0<int>(42);
+ builder.Add(lhs, rhs);
+ ComputeAndCompareR4<int>(&builder, expected, {});
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc
new file mode 100644
index 0000000000..200d4d4563
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc
@@ -0,0 +1,180 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+#include <numeric>
+#include <vector>
+
+#include "tensorflow/compiler/xla/array2d.h"
+#include "tensorflow/compiler/xla/array4d.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace {
+
+using BroadcastSimpleTest = ClientLibraryTestBase;
+
+XLA_TEST_F(BroadcastSimpleTest, ScalarNoOpBroadcast) {
+ ComputationBuilder b(client_, TestName());
+ b.Broadcast(b.ConstantR0<float>(1.5), {});
+ ComputeAndCompareR0<float>(&b, 1.5, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(BroadcastSimpleTest, ScalarTo2D_2x3) {
+ ComputationBuilder b(client_, TestName());
+ b.Broadcast(b.ConstantR0<float>(2.25), {2, 3});
+ Array2D<float> expected(2, 3, 2.25);
+ ComputeAndCompareR2<float>(&b, expected, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(BroadcastSimpleTest, ScalarTo2D_2x0) {
+ ComputationBuilder b(client_, TestName());
+ b.Broadcast(b.ConstantR0<float>(2.25), {2, 0});
+ Array2D<float> expected(2, 0);
+ ComputeAndCompareR2<float>(&b, expected, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(BroadcastSimpleTest, ScalarTo2D_0x2) {
+ ComputationBuilder b(client_, TestName());
+ b.Broadcast(b.ConstantR0<float>(2.25), {0, 2});
+ Array2D<float> expected(0, 2);
+ ComputeAndCompareR2<float>(&b, expected, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(BroadcastSimpleTest, 1DTo2D) {
+ ComputationBuilder b(client_, TestName());
+ b.Broadcast(b.ConstantR1<float>({1, 2, 3}), {2});
+
+ Array2D<float> expected(2, 3);
+ expected(0, 0) = 1;
+ expected(0, 1) = 2;
+ expected(0, 2) = 3;
+ expected(1, 0) = 1;
+ expected(1, 1) = 2;
+ expected(1, 2) = 3;
+ ComputeAndCompareR2<float>(&b, expected, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(BroadcastSimpleTest, ZeroElement_1DTo2D) {
+ ComputationBuilder b(client_, TestName());
+ b.Broadcast(b.ConstantR1<float>({}), {2});
+
+ Array2D<float> expected(2, 0);
+ ComputeAndCompareR2<float>(&b, expected, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(BroadcastSimpleTest, 1DToZeroElement2D) {
+ ComputationBuilder b(client_, TestName());
+ b.Broadcast(b.ConstantR1<float>({1, 2, 3}), {0});
+
+ Array2D<float> expected(0, 3);
+ ComputeAndCompareR2<float>(&b, expected, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(BroadcastSimpleTest, InDimensionAndDegenerateBroadcasting) {
+ // Verify that binary op and degenerate dimension broadcast work together in
+ // the same operation.
+ //
+ // The lhs shape [1, 2] is first broadcast up to [2, 1, 2] using in-dimension
+ // broadcasting (broadcast_dimensions {1, 2}), then is added to the rhs shape
+ // [2, 3, 1]. Degenerate dimension broadcasting then broadcasts the size one
+ // dimensions.
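+  // The resulting shape is [2, 3, 2].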
+ ComputationBuilder b(client_, TestName());
+
+ b.Add(b.ConstantR2<float>({{1.0, 5.0}}),
+ b.ConstantLiteral(*LiteralUtil::CreateR3<float>(
+ {{{2.0}, {3.0}, {4.0}}, {{5.0}, {6.0}, {7.0}}})),
+ /*broadcast_dimensions=*/{1, 2});
+
+ auto expected =
+ LiteralUtil::CreateR3<float>({{{3.0, 7.0}, {4.0, 8.0}, {5.0, 9.0}},
+ {{6.0, 10.0}, {7.0, 11.0}, {8.0, 12.0}}});
+
+ ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(BroadcastSimpleTest, InvalidBinaryAndDegenerateBroadcasting) {
+ // Binary dimension broadcasting of the smaller lhs ([2, 2] up to [2, 2, 2])
+  // results in a shape incompatible with the rhs [2, 3, 1].
+ ComputationBuilder b(client_, TestName());
+
+ b.Add(b.ConstantR2<float>({{1.0, 5.0}, {1.0, 5.0}}),
+ b.ConstantLiteral(*LiteralUtil::CreateR3<float>(
+ {{{2.0}, {3.0}, {4.0}}, {{5.0}, {6.0}, {7.0}}})),
+ /*broadcast_dimensions=*/{1, 2});
+
+ auto result_status = Execute(&b, {});
+ EXPECT_FALSE(result_status.ok());
+ EXPECT_MATCH(result_status.status().error_message(),
+ testing::ContainsRegex("broadcast dimension 0 mismatch"));
+}
+
+XLA_TEST_F(BroadcastSimpleTest, InvalidInDimensionBroadcasting) {
+ // Test invalid broadcasting with [1, 2] and [2, 3] inputs.
+ ComputationBuilder b(client_, TestName());
+
+ b.Add(b.ConstantR2<float>({{1.0, 2.0}}),
+ b.ConstantR2<float>({{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}}));
+
+ auto result_status = Execute(&b, {});
+ EXPECT_FALSE(result_status.ok());
+ EXPECT_MATCH(
+ result_status.status().error_message(),
+ testing::ContainsRegex("binary op BINOP_ADD with incompatible shapes"));
+}
+
+XLA_TEST_F(BroadcastSimpleTest, InvalidDegenerateBroadcasting) {
+ // Test invalid broadcasting with [1, 2] and [2, 3] inputs.
+ ComputationBuilder b(client_, TestName());
+
+ b.Add(b.ConstantR2<float>({{1.0, 2.0}}),
+ b.ConstantR2<float>({{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}}));
+
+ auto result_status = Execute(&b, {});
+ EXPECT_FALSE(result_status.ok());
+ EXPECT_MATCH(
+ result_status.status().error_message(),
+ testing::ContainsRegex("binary op BINOP_ADD with incompatible shapes"));
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/broadcast_test.cc b/tensorflow/compiler/xla/tests/broadcast_test.cc
new file mode 100644
index 0000000000..1796a732e5
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/broadcast_test.cc
@@ -0,0 +1,286 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+#include <utility>
+
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace {
+
+class BroadcastTest : public HloTestBase {};
+
+XLA_TEST_F(BroadcastTest, BroadcastScalarToScalar) {
+ // Test degenerate case of broadcasting a scalar into a scalar.
+ auto builder = HloComputation::Builder(TestName());
+ auto input = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
+ builder.AddInstruction(HloInstruction::CreateBroadcast(
+ ShapeUtil::MakeShape(F32, {}), input, {}));
+
+ // Create HLO module, compile, and execute.
+ auto hlo_module = MakeUnique<HloModule>(TestName());
+ hlo_module->AddEntryComputation(builder.Build());
+ auto result = ExecuteAndTransfer(std::move(hlo_module), {});
+
+ LiteralTestUtil::ExpectNear(*LiteralUtil::CreateR0<float>(42.0), *result,
+ error_spec_);
+}
+
+XLA_TEST_F(BroadcastTest, BroadcastScalarTo2D) {
+ auto builder = HloComputation::Builder(TestName());
+ auto input = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
+ builder.AddInstruction(HloInstruction::CreateBroadcast(
+ ShapeUtil::MakeShape(F32, {2, 2}), input, {}));
+
+ // Create HLO module, compile, and execute.
+ auto hlo_module = MakeUnique<HloModule>(TestName());
+ hlo_module->AddEntryComputation(builder.Build());
+ auto result = ExecuteAndTransfer(std::move(hlo_module), {});
+
+ LiteralTestUtil::ExpectNear(
+ *LiteralUtil::CreateR2<float>({{42.0, 42.0}, {42.0, 42.0}}), *result,
+ error_spec_);
+}
+
+XLA_TEST_F(BroadcastTest, BroadcastVectorTo2D) {
+ auto builder = HloComputation::Builder(TestName());
+ auto input = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0})));
+
+  // Broadcast the vector along dimension 0 and along dimension 1, then join
+  // the results in a tuple so both can be verified.
+ auto element1 = builder.AddInstruction(HloInstruction::CreateBroadcast(
+ ShapeUtil::MakeShape(F32, {3, 2}), input, {0}));
+ auto element2 = builder.AddInstruction(HloInstruction::CreateBroadcast(
+ ShapeUtil::MakeShape(F32, {2, 3}), input, {1}));
+ builder.AddInstruction(HloInstruction::CreateTuple({element1, element2}));
+
+ // Create HLO module, compile, and execute.
+ auto hlo_module = MakeUnique<HloModule>(TestName());
+ hlo_module->AddEntryComputation(builder.Build());
+ auto result = ExecuteAndTransfer(std::move(hlo_module), {});
+
+ LiteralTestUtil::ExpectNear(
+ *LiteralUtil::CreateR2<float>({{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}}),
+ result->tuple_literals(0), error_spec_);
+
+ LiteralTestUtil::ExpectNear(
+ *LiteralUtil::CreateR2<float>({{1.0, 2.0, 3.0}, {1.0, 2.0, 3.0}}),
+ result->tuple_literals(1), error_spec_);
+}
+
+XLA_TEST_F(BroadcastTest, Broadcast2DTo2D) {
+ auto builder = HloComputation::Builder(TestName());
+ auto input = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
+ builder.AddInstruction(HloInstruction::CreateBroadcast(
+ ShapeUtil::MakeShape(F32, {2, 2}), input, {0, 1}));
+
+ // Create HLO module, compile, and execute.
+ auto hlo_module = MakeUnique<HloModule>(TestName());
+ hlo_module->AddEntryComputation(builder.Build());
+ auto result = ExecuteAndTransfer(std::move(hlo_module), {});
+
+ LiteralTestUtil::ExpectNear(
+ *LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}), *result,
+ error_spec_);
+}
+
+XLA_TEST_F(BroadcastTest, Broadcast2DTo2DTranspose) {
+  // Broadcasting a shape into a shape of the same rank with dimensions {1, 0}
+  // reorders the dimensions, i.e. performs a transpose.
+ auto builder = HloComputation::Builder(TestName());
+ auto input = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
+ builder.AddInstruction(HloInstruction::CreateBroadcast(
+ ShapeUtil::MakeShape(F32, {2, 2}), input, {1, 0}));
+
+ // Create HLO module, compile, and execute.
+ auto hlo_module = MakeUnique<HloModule>(TestName());
+ hlo_module->AddEntryComputation(builder.Build());
+ auto result = ExecuteAndTransfer(std::move(hlo_module), {});
+
+ LiteralTestUtil::ExpectNear(
+ *LiteralUtil::CreateR2<float>({{1.0, 3.0}, {2.0, 4.0}}), *result,
+ error_spec_);
+}
+
+XLA_TEST_F(BroadcastTest, Broadcast2DTo3D) {
+ auto builder = HloComputation::Builder(TestName());
+ auto input = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
+ builder.AddInstruction(HloInstruction::CreateBroadcast(
+ ShapeUtil::MakeShape(F32, {2, 3, 2}), input, {0, 2}));
+
+ // Create HLO module, compile, and execute.
+ auto hlo_module = MakeUnique<HloModule>(TestName());
+ hlo_module->AddEntryComputation(builder.Build());
+ auto result = ExecuteAndTransfer(std::move(hlo_module), {});
+
+ LiteralTestUtil::ExpectNear(
+ *LiteralUtil::CreateR3<float>({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}},
+ {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}}),
+ *result, error_spec_);
+}
+
+TEST_F(BroadcastTest, Broadcast_R1_2_To_R4_2x2x3x3) {
+ auto builder = HloComputation::Builder(TestName());
+ auto input = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({1.0, 2.0})));
+
+ // Broadcast vector in dimension 1.
+ builder.AddInstruction(HloInstruction::CreateBroadcast(
+ ShapeUtil::MakeShape(F32, {2, 2, 3, 3}), input, {1}));
+
+ // Create HLO module, compile, and execute.
+ auto hlo_module = MakeUnique<HloModule>(TestName());
+ hlo_module->AddEntryComputation(builder.Build());
+ auto result = ExecuteAndTransfer(std::move(hlo_module), {});
+
+ Array4D<float> expected(2, 2, 3, 3);
+ Array2D<float> pz({{1, 2}, {1, 2}});
+ expected.FillWithPZ(pz);
+
+ LiteralTestUtil::ExpectNear(
+ *LiteralUtil::CreateR4FromArray4D<float>(expected), *result, error_spec_);
+}
+
+TEST_F(BroadcastTest, Broadcast_R1_1025_To_R4_3x3x3x1025) {
+ auto builder = HloComputation::Builder(TestName());
+ std::vector<float> input_data(1025);
+ int64 r1_size = input_data.size();
+ std::iota(input_data.begin(), input_data.end(), 0.0f);
+ auto input = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(input_data)));
+
+ // Broadcast vector in dimension 3.
+ builder.AddInstruction(HloInstruction::CreateBroadcast(
+ ShapeUtil::MakeShape(F32, {3, 3, 3, r1_size}), input, {3}));
+
+ // Create HLO module, compile, and execute.
+ auto hlo_module = MakeUnique<HloModule>(TestName());
+ hlo_module->AddEntryComputation(builder.Build());
+ auto result = ExecuteAndTransfer(std::move(hlo_module), {});
+
+ Array4D<float> expected(3, 3, 3, 1025);
+ Array2D<float> yx(/*height=*/3, /*width=*/r1_size);
+ for (int64 y = 0; y < 3; ++y) {
+ for (int64 x = 0; x < r1_size; ++x) {
+ yx(y, x) = input_data[x];
+ }
+ }
+ expected.FillWithYX(yx);
+
+ LiteralTestUtil::ExpectNear(
+ *LiteralUtil::CreateR4FromArray4D<float>(expected), *result, error_spec_);
+}
+
+XLA_TEST_F(BroadcastTest, Broadcast_R1_64_To_R4_32x64x7x7) {
+ auto builder = HloComputation::Builder(TestName());
+ Array4D<float> r4_array(32, 64, 7, 7);
+ r4_array.Fill(42.0);
+ std::vector<float> r1_array(64, 42.0);
+
+ auto input = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(r1_array)));
+
+ // Broadcast vector in dimension 1.
+ builder.AddInstruction(HloInstruction::CreateBroadcast(
+ ShapeUtil::MakeShape(F32, {32, 64, 7, 7}), input, {1}));
+
+ // Create HLO module, compile, and execute.
+ auto hlo_module = MakeUnique<HloModule>(TestName());
+ hlo_module->AddEntryComputation(builder.Build());
+ auto result = ExecuteAndTransfer(std::move(hlo_module), {});
+
+ LiteralTestUtil::ExpectNear(*LiteralUtil::CreateR4FromArray4D(r4_array),
+ *result, error_spec_);
+}
+
+TEST_F(BroadcastTest, Broadcast_R0_to_R4_64x64x3x3) {
+ auto builder = HloComputation::Builder(TestName());
+ auto input = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
+ builder.AddInstruction(HloInstruction::CreateBroadcast(
+ ShapeUtil::MakeShape(F32, {64, 64, 3, 3}), input, {}));
+
+ // Create HLO module, compile, and execute.
+ auto hlo_module = MakeUnique<HloModule>(TestName());
+ hlo_module->AddEntryComputation(builder.Build());
+ LOG(INFO) << hlo_module->ToString();
+ auto result = ExecuteAndTransfer(std::move(hlo_module), {});
+
+ Array4D<float> expected(64, 64, 3, 3);
+ expected.Fill(1.0f);
+
+ LiteralTestUtil::ExpectNear(
+ *LiteralUtil::CreateR4FromArray4D<float>(expected), *result, error_spec_);
+}
+
+TEST_F(BroadcastTest, Broadcast_R2_2x2_To_R4_3x3x2x2) {
+ auto builder = HloComputation::Builder(TestName());
+ Array2D<float> to_broadcast({{1.0f, 2.0f}, {3.0f, 4.0f}});
+ auto input = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR2FromArray2D<float>(to_broadcast)));
+
+  // Broadcast matrix in dimensions 2 and 3.
+ builder.AddInstruction(HloInstruction::CreateBroadcast(
+ ShapeUtil::MakeShape(F32, {3, 3, 2, 2}), input, {2, 3}));
+
+ // Create HLO module, compile, and execute.
+ auto hlo_module = MakeUnique<HloModule>(TestName());
+ hlo_module->AddEntryComputation(builder.Build());
+ auto result = ExecuteAndTransfer(std::move(hlo_module), {});
+
+ Array4D<float> expected(3, 3, 2, 2);
+ expected.FillWithYX(to_broadcast);
+
+ LiteralTestUtil::ExpectNear(
+ *LiteralUtil::CreateR4FromArray4D<float>(expected), *result, error_spec_);
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/build_defs.bzl b/tensorflow/compiler/xla/tests/build_defs.bzl
new file mode 100644
index 0000000000..2c7eeb820d
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/build_defs.bzl
@@ -0,0 +1,157 @@
+"""Build rules for XLA testing."""
+
+load("@local_config_cuda//cuda:build_defs.bzl", "cuda_is_configured")
+
+def all_backends():
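+  """Returns the backends to test; "gpu" requires CUDA to be configured."""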
+ if cuda_is_configured():
+ return ["cpu", "cpu_parallel", "gpu"]
+ else:
+ return ["cpu", "cpu_parallel"]
+
+def xla_test(name,
+ srcs,
+ deps,
+ backends=[],
+ args=[],
+ tags=[],
+ copts=[],
+ backend_tags={},
+ backend_args={},
+ **kwargs):
+ """Generates cc_test targets for the given XLA backends.
+
+ This rule generates a cc_test target for one or more XLA backends and also
+ a platform-agnostic cc_library rule. The arguments are identical to cc_test
+  with three additions: 'backends', 'backend_args', and 'backend_tags'.
+  'backends' specifies the backends to generate tests for ("cpu",
+  "cpu_parallel", "gpu"), while 'backend_args' and 'backend_tags' specify
+  backend-specific args and tags to use when generating each cc_test.
+
+  The name of each generated cc_test is the provided name argument with the
+  backend name appended, and the cc_library target name is the provided name
+  argument with "_lib" appended. For example, if the name parameter is
+  "foo_test", then the cpu test target will be "foo_test_cpu" and the
+  cc_library target will be "foo_test_lib".
+
+ The cc_library target can be used to link with other plugins outside of
+ xla_test.
+
+ The build rule also defines a test suite ${name} which includes the tests for
+ each of the supported backends.
+
+ Each generated cc_test target has a tag indicating which backend the test is
+  for. This tag is of the form "xla_${BACKEND}" (e.g., "xla_cpu"). These
+ tags can be used to gather tests for a particular backend into a test_suite.
+
+ Examples:
+
+ # Generates the targets: foo_test_cpu and foo_test_gpu.
+ xla_test(
+ name = "foo_test",
+ srcs = ["foo_test.cc"],
+ backends = ["cpu", "gpu"],
+ deps = [...],
+ )
+
+ # Generates the targets: bar_test_cpu and bar_test_gpu. bar_test_cpu
+ # includes the additional arg "--special_cpu_flag".
+ xla_test(
+ name = "bar_test",
+ srcs = ["bar_test.cc"],
+ backends = ["cpu", "gpu"],
+ backend_args = {"cpu": ["--special_cpu_flag"]}
+ deps = [...],
+ )
+
+ The build rule defines the preprocessor macro XLA_TEST_BACKEND_${BACKEND}
+ to the value 1 where ${BACKEND} is the uppercase name of the backend.
+
+ Args:
+ name: Name of the target.
+ srcs: Sources for the target.
+ deps: Dependencies of the target.
+ backends: A list of backends to generate tests for. Supported
+ values: "cpu", "cpu_parallel", "gpu". If this list is empty, the test will
+ be generated for all supported backends.
+ args: Test arguments for the target.
+ tags: Tags for the target.
+ backend_args: A dict mapping backend name to list of additional args to
+ use for that target.
+ backend_tags: A dict mapping backend name to list of additional tags to
+ use for that target.
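+    copts: Additional copts to pass to the generated cc_library and cc_test
+      targets.
+    **kwargs: Additional keyword arguments to pass through to the generated
+      cc_test targets.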
+ """
+ test_names = []
+ if not backends:
+ backends = all_backends()
+
+ native.cc_library(
+ name="%s_lib" % name,
+ srcs=srcs,
+ copts=copts,
+ testonly=True,
+ deps=deps + ["//tensorflow/compiler/xla/tests:test_macros_header"],
+ )
+
+ for backend in backends:
+ test_name = "%s_%s" % (name, backend)
+ this_backend_tags = ["xla_%s" % backend]
+ this_backend_copts = []
+ this_backend_args = backend_args.get(backend, [])
+ if backend == "cpu":
+ backend_deps = ["//tensorflow/compiler/xla/service:cpu_plugin"]
+ backend_deps += ["//tensorflow/compiler/xla/tests:test_macros_cpu"]
+ elif backend == "cpu_parallel":
+ backend_deps = ["//tensorflow/compiler/xla/service:cpu_plugin"]
+ backend_deps += ["//tensorflow/compiler/xla/tests:test_macros_cpu"]
+ this_backend_args += ["--xla_cpu_parallel=true"]
+ elif backend == "gpu":
+ backend_deps = ["//tensorflow/compiler/xla/service:gpu_plugin"]
+ backend_deps += ["//tensorflow/compiler/xla/tests:test_macros_gpu"]
+ this_backend_tags += ["requires-gpu-sm35"]
+ else:
+ fail("Unknown backend %s" % backend)
+
+ native.cc_test(
+ name=test_name,
+ srcs=srcs,
+ tags=tags + backend_tags.get(backend, []) + this_backend_tags,
+ copts=copts + ["-DXLA_TEST_BACKEND_%s=1" % backend.upper()] +
+ this_backend_copts,
+ args=args + this_backend_args,
+ deps=deps + backend_deps,
+ **kwargs)
+
+ test_names.append(test_name)
+
+ native.test_suite(name=name, tests=test_names)
+
+
+def generate_backend_suites(backends=[]):
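+  """Generates a test_suite per backend that gathers its "xla_<backend>" tests."""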
+ if not backends:
+ backends = all_backends()
+ for backend in backends:
+ native.test_suite(name="%s_tests" % backend,
+ tags = ["xla_%s" % backend])
+
+
+def generate_backend_test_macros(backends=[]):
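+  """Generates per-backend test_macros libraries with XLA_PLATFORM defined."""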
+ if not backends:
+ backends = all_backends()
+ for backend in backends:
+ native.cc_library(
+ name="test_macros_%s" % backend,
+ testonly = True,
+ hdrs = ["test_macros.h"],
+ copts = ["-DXLA_PLATFORM=\\\"%s\\\"" % backend.upper()],
+ deps = [
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ])
diff --git a/tensorflow/compiler/xla/tests/call_test.cc b/tensorflow/compiler/xla/tests/call_test.cc
new file mode 100644
index 0000000000..1c96b73034
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/call_test.cc
@@ -0,0 +1,115 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+#include <utility>
+
+#include "tensorflow/compiler/xla/client/computation.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace {
+
+class CallOpTest : public ClientLibraryTestBase {
+ protected:
+ Computation CreateR0F32IdentityComputation() {
+ ComputationBuilder builder(client_, "Identity");
+ builder.Parameter(0, r0f32_, "x");
+ auto build_status = builder.Build();
+ EXPECT_IS_OK(build_status.status());
+ return build_status.ConsumeValueOrDie();
+ }
+
+ Computation CreateR1S0F32AdditionComputation() {
+ ComputationBuilder builder(client_, "Addition");
+ auto x = builder.Parameter(0, r1s0f32_, "x");
+ auto y = builder.Parameter(1, r1s0f32_, "y");
+ builder.Add(x, y);
+ auto build_status = builder.Build();
+ EXPECT_IS_OK(build_status.status());
+ return build_status.ConsumeValueOrDie();
+ }
+
+ Computation CreateR1S2F32AdditionComputation() {
+ ComputationBuilder builder(client_, "Addition");
+ auto x = builder.Parameter(0, r1s2f32_, "x");
+ auto y = builder.Parameter(1, r1s2f32_, "y");
+ builder.Add(x, y);
+ auto build_status = builder.Build();
+ EXPECT_IS_OK(build_status.status());
+ return build_status.ConsumeValueOrDie();
+ }
+
+ Shape r0f32_ = ShapeUtil::MakeShape(F32, {});
+ Shape r1s0f32_ = ShapeUtil::MakeShape(F32, {0});
+ Shape r1s2f32_ = ShapeUtil::MakeShape(F32, {2});
+};
+
+XLA_TEST_F(CallOpTest, DISABLED_ON_GPU(CallR0F32IdentityScalar)) {
+ ComputationBuilder builder(client_, TestName());
+ Computation callee = CreateR0F32IdentityComputation();
+ auto constant = builder.ConstantLiteral(*LiteralUtil::CreateR0<float>(42.0));
+ builder.Call(callee, {constant});
+
+ ComputeAndCompareR0<float>(&builder, 42.0, {}, ErrorSpec(0.01f));
+}
+
+XLA_TEST_F(CallOpTest, DISABLED_ON_GPU(CallR1S0F32AddArray)) {
+ ComputationBuilder builder(client_, TestName());
+ Computation callee = CreateR1S0F32AdditionComputation();
+ auto x = builder.ConstantLiteral(*LiteralUtil::CreateR1<float>({}));
+ auto y = builder.ConstantLiteral(*LiteralUtil::CreateR1<float>({}));
+ builder.Call(callee, {x, y});
+
+ ComputeAndCompareR1<float>(&builder, {}, {}, ErrorSpec(0.01f));
+}
+
+XLA_TEST_F(CallOpTest, DISABLED_ON_GPU(CallR1S2F32AddArray)) {
+ ComputationBuilder builder(client_, TestName());
+ Computation callee = CreateR1S2F32AdditionComputation();
+ auto x = builder.ConstantLiteral(*LiteralUtil::CreateR1<float>({1.0f, 2.0f}));
+ auto y = builder.ConstantLiteral(*LiteralUtil::CreateR1<float>({2.0f, 3.0f}));
+ builder.Call(callee, {x, y});
+
+ ComputeAndCompareR1<float>(&builder, {3.0f, 5.0f}, {}, ErrorSpec(0.01f));
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/check_execution_arity_test.cc b/tensorflow/compiler/xla/tests/check_execution_arity_test.cc
new file mode 100644
index 0000000000..675c9fccb0
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/check_execution_arity_test.cc
@@ -0,0 +1,138 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/global_data.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace {
+
+class CheckExecutionArityTest : public ClientLibraryTestBase {};
+
+TEST_F(CheckExecutionArityTest, TwoParamComputationNumArguments) {
+ ComputationBuilder builder(client_, "add_two_params");
+ auto param_literal = LiteralUtil::CreateR1<float>({1.1f, 2.2f});
+
+ auto p0 = builder.Parameter(0, param_literal->shape(), "param0");
+ auto p1 = builder.Parameter(1, param_literal->shape(), "param1");
+ auto add = builder.Add(p0, p1);
+
+ auto param0_data =
+ client_->TransferToServer(*param_literal).ConsumeValueOrDie();
+ auto param1_data =
+ client_->TransferToServer(*param_literal).ConsumeValueOrDie();
+
+ auto computation_status = builder.Build();
+ ASSERT_IS_OK(computation_status.status());
+ auto computation = computation_status.ConsumeValueOrDie();
+
+ // The arity of the UserComputation is 2 arguments. Execution will succeed
+ // with 2 arguments, but fail with a different number.
+ auto result_two_args =
+ client_->Execute(computation, {param0_data.get(), param1_data.get()});
+ ASSERT_IS_OK(result_two_args.status());
+
+ auto result_one_arg = client_->Execute(computation, {param0_data.get()});
+ ASSERT_FALSE(result_one_arg.ok());
+ ASSERT_EQ(result_one_arg.status().code(),
+ tensorflow::error::INVALID_ARGUMENT);
+ ASSERT_MATCH(result_one_arg.status().error_message(),
+ testing::ContainsRegex("takes 2"));
+
+ auto result_zero_args = client_->Execute(computation, {});
+ ASSERT_FALSE(result_zero_args.ok());
+ ASSERT_EQ(result_zero_args.status().code(),
+ tensorflow::error::INVALID_ARGUMENT);
+ ASSERT_MATCH(result_zero_args.status().error_message(),
+ testing::ContainsRegex("takes 2"));
+}
+
+XLA_TEST_F(CheckExecutionArityTest, CheckArgumentShapes) {
+ ComputationBuilder builder(client_, "add_two_params");
+
+ auto p0 = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param0");
+ auto p1 = builder.Parameter(1, ShapeUtil::MakeShape(F32, {4}), "param1");
+  auto mul = builder.Mul(p0, p1);
+
+ auto computation_status = builder.Build();
+ ASSERT_IS_OK(computation_status.status());
+ auto computation = computation_status.ConsumeValueOrDie();
+
+ auto f32_literal = LiteralUtil::CreateR0<float>(1.1f);
+ auto f32_data = client_->TransferToServer(*f32_literal).ConsumeValueOrDie();
+ auto f32_4_literal = LiteralUtil::CreateR1<float>({1.0f, 2.0f, 3.0f, 4.0f});
+ auto f32_4_data =
+ client_->TransferToServer(*f32_4_literal).ConsumeValueOrDie();
+ auto u8_4_literal = LiteralUtil::CreateR1U8("hola");
+ auto u8_4_data = client_->TransferToServer(*u8_4_literal).ConsumeValueOrDie();
+
+ // Match
+ auto status =
+ client_->Execute(computation, {f32_data.get(), f32_4_data.get()});
+ ASSERT_IS_OK(status.status());
+
+ // Shape mismatch in parameter 0
+ status = client_->Execute(computation, {f32_4_data.get(), f32_4_data.get()});
+ ASSERT_FALSE(status.ok());
+ ASSERT_EQ(status.status().code(), tensorflow::error::INVALID_ARGUMENT);
+ ASSERT_MATCH(status.status().error_message(),
+ testing::ContainsRegex("expects parameter 0"));
+
+ // Shape mismatch in parameter 1 (rank)
+ status = client_->Execute(computation, {f32_data.get(), f32_data.get()});
+ ASSERT_FALSE(status.ok());
+ ASSERT_EQ(status.status().code(), tensorflow::error::INVALID_ARGUMENT);
+ ASSERT_MATCH(status.status().error_message(),
+ testing::ContainsRegex("expects parameter 1"));
+
+ // Shape mismatch in parameter 1 (element type)
+ status = client_->Execute(computation, {f32_data.get(), u8_4_data.get()});
+ ASSERT_FALSE(status.ok());
+ ASSERT_EQ(status.status().code(), tensorflow::error::INVALID_ARGUMENT);
+ ASSERT_MATCH(status.status().error_message(),
+ testing::ContainsRegex("expects parameter 1"));
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc
new file mode 100644
index 0000000000..d2a7def5d0
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc
@@ -0,0 +1,263 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+
+#include <string>
+
+#include "tensorflow/compiler/xla/client/client_library.h"
+#include "tensorflow/compiler/xla/client/computation.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/legacy_flags/hlo_pass_pipeline_flags.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace se = ::perftools::gputools;
+
+namespace xla {
+namespace {
+// Wrapper function that creates a nicer error message (than a bare
+// ValueOrDie()) if the platform we intend to test is not available.
+Client* GetOrCreateLocalClientOrDie(se::Platform* platform) {
+ StatusOr<Client*> result = ClientLibrary::GetOrCreateLocalClient(platform);
+ TF_CHECK_OK(result.status()) << "could not create local client for testing";
+ return result.ValueOrDie();
+}
+} // namespace
+
+ClientLibraryTestBase::ClientLibraryTestBase(
+ se::Platform* platform,
+ tensorflow::gtl::ArraySlice<string> disabled_pass_names)
+ : client_(GetOrCreateLocalClientOrDie(platform)) {
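+  // The disabled pass names are joined into the single comma-separated
+  // xla_disable_hlo_passes flag, e.g. {"foo", "bar"} becomes "foo,bar".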
+ legacy_flags::HloPassPipelineFlags* flags =
+ legacy_flags::GetHloPassPipelineFlags();
+ flags->xla_disable_hlo_passes =
+ tensorflow::str_util::Join(disabled_pass_names, ",");
+}
+
+string ClientLibraryTestBase::TestName() const {
+ return ::testing::UnitTest::GetInstance()->current_test_info()->name();
+}
+
+StatusOr<std::unique_ptr<GlobalData>> ClientLibraryTestBase::Execute(
+ ComputationBuilder* builder,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
+ // Build the computation, as a convenience.
+ TF_ASSIGN_OR_RETURN(auto computation, builder->Build());
+ return client_->Execute(computation, arguments);
+}
+
+StatusOr<std::unique_ptr<Literal>> ClientLibraryTestBase::ExecuteAndTransfer(
+ ComputationBuilder* builder,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments,
+ const Shape* shape_with_output_layout) {
+ // Build the computation, as a convenience.
+ TF_ASSIGN_OR_RETURN(auto computation, builder->Build());
+ return client_->ExecuteAndTransfer(computation, arguments,
+ shape_with_output_layout);
+}
+
+std::unique_ptr<GlobalData> ClientLibraryTestBase::ExecuteOrDie(
+ ComputationBuilder* builder,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
+ return Execute(builder, arguments).ConsumeValueOrDie();
+}
+
+std::unique_ptr<Literal> ClientLibraryTestBase::ExecuteAndTransferOrDie(
+ ComputationBuilder* builder,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
+ return ExecuteAndTransfer(builder, arguments).ConsumeValueOrDie();
+}
+
+string ClientLibraryTestBase::ExecuteToString(
+ ComputationBuilder* builder,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
+ StatusOr<Computation> computation_status = builder->Build();
+ if (!computation_status.ok()) {
+ return computation_status.status().ToString();
+ }
+ Computation computation = computation_status.ConsumeValueOrDie();
+
+ auto result = client_->ExecuteAndTransfer(computation, arguments);
+ if (!result.ok()) {
+ return result.status().ToString();
+ } else {
+ return LiteralUtil::ToString(*result.ValueOrDie());
+ }
+}
+
+void ClientLibraryTestBase::ComputeAndCompareR1(
+ ComputationBuilder* builder, const tensorflow::core::Bitmap& expected,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
+ std::unique_ptr<Literal> expected_literal = LiteralUtil::CreateR1(expected);
+ ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
+ arguments);
+}
+
+void ClientLibraryTestBase::ComputeAndCompareLiteral(
+ ComputationBuilder* builder, const Literal& expected,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments,
+ const Shape* shape_with_layout) {
+ EXPECT_IS_OK(ComputeAndCompareLiteralWithStatus(builder, expected, arguments,
+ shape_with_layout));
+}
+
+void ClientLibraryTestBase::ComputeAndCompareLiteral(
+ ComputationBuilder* builder, const Literal& expected,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error,
+ const Shape* shape_with_layout) {
+ EXPECT_IS_OK(ComputeAndCompareLiteralWithStatus(builder, expected, arguments,
+ error, shape_with_layout));
+}
+
+tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
+ ComputationBuilder* builder, const Literal& expected,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments,
+ const Shape* shape_with_layout) {
+ TF_ASSIGN_OR_RETURN(
+ auto actual, ExecuteAndTransfer(builder, arguments, shape_with_layout));
+ if (ShapeUtil::ElementIsFloating(expected.shape())) {
+ LOG(WARNING) << "performing exact comparison of floating point numbers";
+ } else {
+ TF_RET_CHECK(ShapeUtil::ElementIsIntegral(expected.shape()) ||
+ expected.shape().element_type() == PRED);
+ }
+ LiteralTestUtil::ExpectEqual(expected, *actual);
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
+ ComputationBuilder* builder, const Literal& expected,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error,
+ const Shape* shape_with_layout) {
+ TF_ASSIGN_OR_RETURN(
+ auto actual, ExecuteAndTransfer(builder, arguments, shape_with_layout));
+ TF_RET_CHECK(ShapeUtil::ElementIsFloating(expected.shape()));
+ LiteralTestUtil::ExpectNear(expected, *actual, error);
+ return tensorflow::Status::OK();
+}
+
+void ClientLibraryTestBase::ComputeAndCompareR1U8(
+ ComputationBuilder* builder, tensorflow::StringPiece expected,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
+ auto actual_status = ExecuteAndTransfer(builder, arguments);
+ EXPECT_IS_OK(actual_status.status());
+ if (!actual_status.ok()) {
+ return;
+ }
+ auto actual = actual_status.ConsumeValueOrDie();
+
+ // Turn the expected value into a literal.
+ std::unique_ptr<Literal> expected_literal = LiteralUtil::CreateR1U8(expected);
+
+ VLOG(1) << "expected: " << LiteralUtil::ToString(*expected_literal);
+ VLOG(1) << "actual: " << LiteralUtil::ToString(*actual);
+
+ EXPECT_EQ(expected, actual->u8s());
+}
+
+void ClientLibraryTestBase::ComputeAndCompareTuple(
+ ComputationBuilder* builder, const Literal& expected,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
+ auto actual_status = ExecuteAndTransfer(builder, arguments);
+ EXPECT_IS_OK(actual_status.status());
+ if (!actual_status.ok()) {
+ return;
+ }
+ auto actual = actual_status.ConsumeValueOrDie();
+ LiteralTestUtil::ExpectEqualTuple(expected, *actual);
+}
+
+void ClientLibraryTestBase::ComputeAndCompareTuple(
+ ComputationBuilder* builder, const Literal& expected,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error) {
+ auto actual_status = ExecuteAndTransfer(builder, arguments);
+ EXPECT_IS_OK(actual_status.status());
+ if (!actual_status.ok()) {
+ return;
+ }
+ auto actual = actual_status.ConsumeValueOrDie();
+ LiteralTestUtil::ExpectNearTuple(expected, *actual, error);
+}
+
+Computation ClientLibraryTestBase::CreateScalarRelu() {
+ ComputationBuilder builder(client_, "relu");
+ auto z_value = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "z_value");
+ auto zero = builder.ConstantR0<float>(0.0);
+ builder.Max(z_value, zero);
+ auto computation_status = builder.Build();
+ TF_CHECK_OK(computation_status.status());
+ return computation_status.ConsumeValueOrDie();
+}
+
+Computation ClientLibraryTestBase::CreateScalarMax() {
+ ComputationBuilder builder(client_, "max");
+ auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
+ auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y");
+ builder.Max(x, y);
+ auto computation_status = builder.Build();
+ TF_CHECK_OK(computation_status.status());
+ return computation_status.ConsumeValueOrDie();
+}
+
+Computation ClientLibraryTestBase::CreateScalarReluSensitivity() {
+ ComputationBuilder builder(client_, "relu_sensitivity");
+ auto activation =
+ builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "activation");
+ auto backprop =
+ builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "backprop");
+ auto zero = builder.ConstantR0<float>(0.0);
+ auto activation_gtz = builder.Gt(activation, zero);
+ builder.Select(activation_gtz, /*on_true=*/backprop, /*on_false=*/zero);
+
+ auto computation_status = builder.Build();
+ TF_CHECK_OK(computation_status.status());
+ return computation_status.ConsumeValueOrDie();
+}
+
+std::unique_ptr<Array2D<float>> ClientLibraryTestBase::CreatePatternedMatrix(
+ int rows, int cols, float offset) {
+ auto array = MakeUnique<Array2D<float>>(rows, cols);
+ for (int64 row = 0; row < rows; ++row) {
+ for (int64 col = 0; col < cols; ++col) {
+ (*array)(row, col) = col + (row * 1000.0f) + offset;
+ }
+ }
+ return array;
+}
+
+std::unique_ptr<Array2D<float>>
+ClientLibraryTestBase::CreatePatternedMatrixWithZeroPadding(int rows, int cols,
+ int rows_padded,
+ int cols_padded) {
+ CHECK_GE(rows_padded, rows);
+ CHECK_GE(cols_padded, cols);
+ auto array = MakeUnique<Array2D<float>>(rows_padded, cols_padded, 0.0);
+ for (int64 row = 0; row < rows; ++row) {
+ for (int64 col = 0; col < cols; ++col) {
+ (*array)(row, col) = col + (row * 1000.0f);
+ }
+ }
+ return array;
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h
new file mode 100644
index 0000000000..690fda3ffa
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/client_library_test_base.h
@@ -0,0 +1,409 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_TESTS_CLIENT_LIBRARY_TEST_BASE_H_
+#define TENSORFLOW_COMPILER_XLA_TESTS_CLIENT_LIBRARY_TEST_BASE_H_
+
+#include <memory>
+#include <string>
+#include <type_traits>
+#include <vector>
+
+#include "tensorflow/compiler/xla/array2d.h"
+#include "tensorflow/compiler/xla/array3d.h"
+#include "tensorflow/compiler/xla/array4d.h"
+#include "tensorflow/compiler/xla/client/computation.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/global_data.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_utils.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/bitmap.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+
+// A client library test establishes an in-process XLA client connection.
+class ClientLibraryTestBase : public ::testing::Test {
+ protected:
+ explicit ClientLibraryTestBase(
+ perftools::gputools::Platform* platform = nullptr,
+ tensorflow::gtl::ArraySlice<string> disabled_pass_names = {});
+
+ // Returns the name of the test currently being run.
+ string TestName() const;
+
+ // TODO(b/25566808): Add helper that populates a literal from a testdata file.
+
+ // Convenience methods for building and running a computation from a builder.
+ StatusOr<std::unique_ptr<GlobalData>> Execute(
+ ComputationBuilder* builder,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments);
+ StatusOr<std::unique_ptr<Literal>> ExecuteAndTransfer(
+ ComputationBuilder* builder,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments,
+ const Shape* shape_with_output_layout = nullptr);
+
+ // Convenience OrDie variants of above methods.
+ std::unique_ptr<GlobalData> ExecuteOrDie(
+ ComputationBuilder* builder,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments);
+ std::unique_ptr<Literal> ExecuteAndTransferOrDie(
+ ComputationBuilder* builder,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments);
+
+  // Runs a computation and returns its value as a string. If an error
+  // occurs, returns the error as a string instead.
+ string ExecuteToString(ComputationBuilder* builder,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments);
+
+  // Convenience methods for building and running a computation, transferring
+  // the result, and comparing it to the expected value(s). Methods are
+  // templated on the native host type, which maps to a specific XLA type (see
+  // ComputationBuilder for details). For each rank, two forms are provided:
+  // one for floating point types, taking an ErrorSpec parameter, and one for
+  // integral types, without the ErrorSpec parameter.
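+  //
+  // For example, a test built on this class might check a unary op via:
+  //   ComputationBuilder b(client_, TestName());
+  //   b.Neg(b.ConstantR1<float>({1.0f, -2.0f}));
+  //   ComputeAndCompareR1<float>(&b, {-1.0f, 2.0f}, {}, ErrorSpec(0.0001));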
+ template <typename NativeT>
+ void ComputeAndCompareR0(ComputationBuilder* builder, NativeT expected,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments);
+ template <typename NativeT>
+ void ComputeAndCompareR0(ComputationBuilder* builder, NativeT expected,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments,
+ ErrorSpec error);
+
+ template <typename NativeT>
+ void ComputeAndCompareR1(ComputationBuilder* builder,
+ tensorflow::gtl::ArraySlice<NativeT> expected,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments);
+ template <typename NativeT>
+ void ComputeAndCompareR1(ComputationBuilder* builder,
+ tensorflow::gtl::ArraySlice<NativeT> expected,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments,
+ ErrorSpec error);
+
+ // As above, but uses a bitmap to hold the predicate vector to avoid
+ // deficiencies of vector<bool>.
+ void ComputeAndCompareR1(ComputationBuilder* builder,
+ const tensorflow::core::Bitmap& expected,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments);
+
+ template <typename NativeT>
+ void ComputeAndCompareR2(ComputationBuilder* builder,
+ const Array2D<NativeT>& expected,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments);
+ template <typename NativeT>
+ void ComputeAndCompareR2(ComputationBuilder* builder,
+ const Array2D<NativeT>& expected,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments,
+ ErrorSpec error);
+
+ template <typename NativeT>
+ void ComputeAndCompareR3(ComputationBuilder* builder,
+ const Array3D<NativeT>& expected,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments);
+ template <typename NativeT>
+ void ComputeAndCompareR3(ComputationBuilder* builder,
+ const Array3D<NativeT>& expected,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments,
+ ErrorSpec error);
+
+ template <typename NativeT>
+ void ComputeAndCompareR4(ComputationBuilder* builder,
+ const Array4D<NativeT>& expected,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments);
+ template <typename NativeT>
+ void ComputeAndCompareR4(ComputationBuilder* builder,
+ const Array4D<NativeT>& expected,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments,
+ ErrorSpec error);
+
+ // Build and run the computation and compare the result with the given
+ // literal. shape_with_layout indicates the result layout to request when
+ // calling Execute.
+ void ComputeAndCompareLiteral(
+ ComputationBuilder* builder, const Literal& expected,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments,
+ const Shape* shape_with_layout = nullptr);
+ void ComputeAndCompareLiteral(
+ ComputationBuilder* builder, const Literal& expected,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error,
+ const Shape* shape_with_layout = nullptr);
+
+ // ComputeAndCompare variant which returns an error status.
+ tensorflow::Status ComputeAndCompareLiteralWithStatus(
+ ComputationBuilder* builder, const Literal& expected,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments,
+ const Shape* shape_with_layout = nullptr);
+ tensorflow::Status ComputeAndCompareLiteralWithStatus(
+ ComputationBuilder* builder, const Literal& expected,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error,
+ const Shape* shape_with_layout = nullptr);
+
+  // Compares the result of the computation to a string. In XLA, strings are
+  // represented as rank-1 U8 shapes.
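+  // For example: ComputeAndCompareR1U8(&builder, "hola", {});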
+ void ComputeAndCompareR1U8(
+ ComputationBuilder* builder, tensorflow::StringPiece expected,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments);
+
+ // Convenience method for running a built computation, transferring the
+ // result, and comparing it to the expected tuple literal.
+ void ComputeAndCompareTuple(
+ ComputationBuilder* builder, const Literal& expected,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments);
+ void ComputeAndCompareTuple(
+ ComputationBuilder* builder, const Literal& expected,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec abs_error);
+
+ // Create scalar operations for use in reductions.
+ Computation CreateScalarRelu();
+ Computation CreateScalarMax();
+ Computation CreateScalarReluSensitivity();
+
+ // Special case convenience functions for creating filled arrays.
+
+ // Creates an array of pseudorandom values lying between the given minimum and
+ // maximum values.
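+  // For example, CreatePseudorandomR1<float>(/*width=*/4, 0.0f, 1.0f,
+  // /*seed=*/42) returns four floats between 0.0 and 1.0.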
+ template <typename NativeT>
+ std::vector<NativeT> CreatePseudorandomR1(const int width, NativeT min_value,
+ NativeT max_value, uint32 seed);
+ template <typename NativeT>
+ std::unique_ptr<Array2D<NativeT>> CreatePseudorandomR2(const int rows,
+ const int cols,
+ NativeT min_value,
+ NativeT max_value,
+ uint32 seed);
+
+ // Creates a (rows x cols) array filled in the following form:
+ //
+  //  [      0              1 ...                   cols-1]
+  //  [ 1000.0         1001.0 ...          1000.0 + cols-1]
+  //  [    ...            ... ...                      ...]
+  //  [(rows-1)*1000.0    ... ... (rows-1)*1000.0 + cols-1]
+ //
+ // If provided, offset is added uniformly to every element (e.g. an offset of
+ // 64 would cause 0 in the above to be 64, 1 to be 65, 1000 to be 1064, etc.)
+ std::unique_ptr<Array2D<float>> CreatePatternedMatrix(const int rows,
+ const int cols,
+ float offset = 0.0);
+
+ // Creates a (rows x cols) array as above, padded out to
+ // (rows_padded x cols_padded) with zeroes. Requires rows_padded >= rows
+  // and cols_padded >= cols.
+ std::unique_ptr<Array2D<float>> CreatePatternedMatrixWithZeroPadding(
+ const int rows, const int cols, const int rows_padded,
+ const int cols_padded);
+
+  // Creates a parameter instruction that wraps the given values and then
+  // stores the global handle for that parameter into "data_handle".
+ //
+ // "parameter_number" is the parameter number.
+ // "name" is the name of the parameter instruction.
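+  //
+  // For example, a test might build and check a negated parameter via:
+  //   ComputationDataHandle param;
+  //   auto data = CreateR1Parameter<float>({1.0f, 2.0f}, /*parameter_number=*/0,
+  //                                        "param0", &builder, &param);
+  //   builder.Neg(param);
+  //   ComputeAndCompareR1<float>(&builder, {-1.0f, -2.0f}, {data.get()},
+  //                              ErrorSpec(0.0001));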
+ template <typename NativeT>
+ std::unique_ptr<GlobalData> CreateR1Parameter(
+ tensorflow::gtl::ArraySlice<NativeT> values, int64 parameter_number,
+ const string& name, ComputationBuilder* builder,
+ ComputationDataHandle* data_handle);
+
+  // Creates a parameter instruction that wraps the given constant array
+  // "array_2d" and then stores the global handle for that parameter into
+  // "data_handle".
+ //
+ // "parameter_number" is the parameter number.
+ // "name" is the name of the parameter instruction.
+ template <typename NativeT>
+ std::unique_ptr<GlobalData> CreateR2Parameter(
+ const Array2D<NativeT>& array_2d, int64 parameter_number,
+ const string& name, ComputationBuilder* builder,
+ ComputationDataHandle* data_handle);
+
+ Client* client_;
+};
+
+template <typename NativeT>
+void ClientLibraryTestBase::ComputeAndCompareR0(
+ ComputationBuilder* builder, NativeT expected,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
+ std::unique_ptr<Literal> expected_literal =
+ LiteralUtil::CreateR0<NativeT>(expected);
+ ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
+ arguments);
+}
+
+template <typename NativeT>
+void ClientLibraryTestBase::ComputeAndCompareR0(
+ ComputationBuilder* builder, NativeT expected,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error) {
+ static_assert(std::is_same<NativeT, float>::value ||
+ std::is_same<NativeT, double>::value,
+ "Floating point type required when specifying an ErrorSpec");
+ std::unique_ptr<Literal> expected_literal =
+ LiteralUtil::CreateR0<NativeT>(expected);
+ ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
+ arguments, error);
+}
+
+template <typename NativeT>
+void ClientLibraryTestBase::ComputeAndCompareR1(
+ ComputationBuilder* builder, tensorflow::gtl::ArraySlice<NativeT> expected,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
+ std::unique_ptr<Literal> expected_literal =
+ LiteralUtil::CreateR1<NativeT>(expected);
+ ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
+ arguments);
+}
+
+template <typename NativeT>
+void ClientLibraryTestBase::ComputeAndCompareR1(
+ ComputationBuilder* builder, tensorflow::gtl::ArraySlice<NativeT> expected,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error) {
+ static_assert(std::is_same<NativeT, float>::value ||
+ std::is_same<NativeT, double>::value,
+ "Floating point type required when specifying an ErrorSpec");
+ std::unique_ptr<Literal> expected_literal =
+ LiteralUtil::CreateR1<NativeT>(expected);
+ ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
+ arguments, error);
+}
+
+template <typename NativeT>
+void ClientLibraryTestBase::ComputeAndCompareR2(
+ ComputationBuilder* builder, const Array2D<NativeT>& expected,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
+ std::unique_ptr<Literal> expected_literal =
+ LiteralUtil::CreateR2FromArray2D<NativeT>(expected);
+ ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
+ arguments);
+}
+
+template <typename NativeT>
+void ClientLibraryTestBase::ComputeAndCompareR2(
+ ComputationBuilder* builder, const Array2D<NativeT>& expected,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error) {
+ static_assert(std::is_same<NativeT, float>::value ||
+ std::is_same<NativeT, double>::value,
+ "Floating point type required when specifying an ErrorSpec");
+ std::unique_ptr<Literal> expected_literal =
+ LiteralUtil::CreateR2FromArray2D<NativeT>(expected);
+ ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
+ arguments, error);
+}
+
+template <typename NativeT>
+void ClientLibraryTestBase::ComputeAndCompareR3(
+ ComputationBuilder* builder, const Array3D<NativeT>& expected,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
+ std::unique_ptr<Literal> expected_literal =
+ LiteralUtil::CreateR3FromArray3D<NativeT>(expected);
+ ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
+ arguments);
+}
+
+template <typename NativeT>
+void ClientLibraryTestBase::ComputeAndCompareR3(
+ ComputationBuilder* builder, const Array3D<NativeT>& expected,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error) {
+ static_assert(std::is_same<NativeT, float>::value ||
+ std::is_same<NativeT, double>::value,
+ "Floating point type required when specifying an ErrorSpec");
+ std::unique_ptr<Literal> expected_literal =
+ LiteralUtil::CreateR3FromArray3D<NativeT>(expected);
+ ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
+ arguments, error);
+}
+
+template <typename NativeT>
+void ClientLibraryTestBase::ComputeAndCompareR4(
+ ComputationBuilder* builder, const Array4D<NativeT>& expected,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
+ std::unique_ptr<Literal> expected_literal =
+ LiteralUtil::CreateR4FromArray4D<NativeT>(expected);
+ ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
+ arguments);
+}
+
+template <typename NativeT>
+void ClientLibraryTestBase::ComputeAndCompareR4(
+ ComputationBuilder* builder, const Array4D<NativeT>& expected,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error) {
+ static_assert(std::is_same<NativeT, float>::value ||
+ std::is_same<NativeT, double>::value,
+ "Floating point type required when specifying an ErrorSpec");
+ std::unique_ptr<Literal> expected_literal =
+ LiteralUtil::CreateR4FromArray4D<NativeT>(expected);
+ ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
+ arguments, error);
+}
+
+template <typename NativeT>
+std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR1Parameter(
+ tensorflow::gtl::ArraySlice<NativeT> values, int64 parameter_number,
+ const string& name, ComputationBuilder* builder,
+ ComputationDataHandle* data_handle) {
+ std::unique_ptr<Literal> literal = LiteralUtil::CreateR1(values);
+ std::unique_ptr<GlobalData> data =
+ client_->TransferToServer(*literal).ConsumeValueOrDie();
+ *data_handle = builder->Parameter(parameter_number, literal->shape(), name);
+ return data;
+}
+
+template <typename NativeT>
+std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR2Parameter(
+ const Array2D<NativeT>& array_2d, int64 parameter_number,
+ const string& name, ComputationBuilder* builder,
+ ComputationDataHandle* data_handle) {
+ std::unique_ptr<Literal> literal = LiteralUtil::CreateR2FromArray2D(array_2d);
+ std::unique_ptr<GlobalData> data =
+ client_->TransferToServer(*literal).ConsumeValueOrDie();
+ *data_handle = builder->Parameter(parameter_number, literal->shape(), name);
+ return data;
+}
+
+template <typename NativeT>
+std::vector<NativeT> ClientLibraryTestBase::CreatePseudorandomR1(
+ const int width, NativeT min_value, NativeT max_value, uint32 seed) {
+ std::vector<NativeT> result(width);
+ test_utils::PseudorandomGenerator<NativeT> generator(min_value, max_value,
+ seed);
+ for (int i = 0; i < width; ++i) {
+ result[i] = generator.get();
+ }
+ return result;
+}
+
+template <typename NativeT>
+std::unique_ptr<Array2D<NativeT>> ClientLibraryTestBase::CreatePseudorandomR2(
+ const int rows, const int cols, NativeT min_value, NativeT max_value,
+ uint32 seed) {
+ auto result = MakeUnique<Array2D<NativeT>>(rows, cols);
+ test_utils::PseudorandomGenerator<NativeT> generator(min_value, max_value,
+ seed);
+ for (int y = 0; y < rows; ++y) {
+ for (int x = 0; x < cols; ++x) {
+ (*result)(y, x) = generator.get();
+ }
+ }
+ return result;
+}
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_TESTS_CLIENT_LIBRARY_TEST_BASE_H_
diff --git a/tensorflow/compiler/xla/tests/client_test.cc b/tensorflow/compiler/xla/tests/client_test.cc
new file mode 100644
index 0000000000..77b85af83c
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/client_test.cc
@@ -0,0 +1,127 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+#include <vector>
+
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/global_data.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_utils.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+class ClientTest : public ClientLibraryTestBase {};
+
+TEST_F(ClientTest, ExecuteWithLayout) {
+ ComputationBuilder b(client_, TestName());
+
+ std::vector<std::vector<int64>> layouts = {{0, 1}, {1, 0}};
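+  // A minor_to_major order of {1, 0} is row-major; {0, 1} is column-major.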
+ for (const std::vector<int64>& execute_layout : layouts) {
+ for (const std::vector<int64>& transfer_layout : layouts) {
+ b.Add(b.ConstantR2<int32>({{1, 2}, {3, 4}}),
+ b.ConstantR2<int32>({{10, 20}, {30, 40}}));
+ auto computation = b.Build();
+ ASSERT_TRUE(computation.ok()) << computation.status();
+
+ const Shape execute_shape_with_layout = ShapeUtil::MakeShapeWithLayout(
+ S32, /*dimensions=*/{2, 2}, execute_layout);
+ std::unique_ptr<GlobalData> data =
+ client_
+ ->Execute(computation.ValueOrDie(), {},
+ &execute_shape_with_layout)
+ .ConsumeValueOrDie();
+
+ std::unique_ptr<Literal> expected_literal =
+ test_utils::CreateR2LiteralWithLayout<int32>({{11, 22}, {33, 44}},
+ transfer_layout);
+
+ auto computed = client_->Transfer(*data, &expected_literal->shape());
+
+ LiteralTestUtil::AssertEqualShapesAndLayouts(
+ expected_literal->shape(), computed.ValueOrDie()->shape());
+ LiteralTestUtil::ExpectEqual(*expected_literal, *computed.ValueOrDie());
+ }
+ }
+}
+
+TEST_F(ClientTest, ExecuteWithTupleLayout) {
+ ComputationBuilder b(client_, TestName());
+
+ b.Tuple({b.ConstantR2<int32>({{1, 2}, {3, 4}}),
+ b.ConstantR2<int32>({{10, 20}, {30, 40}})});
+
+ auto computation = b.Build();
+ ASSERT_TRUE(computation.ok()) << computation.status();
+
+  // Create a result shape with one element column-major and the other
+  // row-major.
+ Shape tuple_shape_with_layout = ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeShapeWithLayout(S32, /*dimensions=*/{2, 2},
+ /*minor_to_major=*/{0, 1}),
+ ShapeUtil::MakeShapeWithLayout(S32, /*dimensions=*/{2, 2},
+ /*minor_to_major=*/{1, 0})});
+
+ auto result = client_
+ ->ExecuteAndTransfer(computation.ValueOrDie(), {},
+ &tuple_shape_with_layout)
+ .ConsumeValueOrDie();
+ LiteralTestUtil::ExpectR2Equal<int32>({{1, 2}, {3, 4}},
+ result->tuple_literals(0));
+ LiteralTestUtil::ExpectR2Equal<int32>({{10, 20}, {30, 40}},
+ result->tuple_literals(1));
+
+ EXPECT_TRUE(ShapeUtil::IsTuple(result->shape()));
+ EXPECT_EQ(2, ShapeUtil::TupleElementCount(result->shape()));
+
+ EXPECT_TRUE(ShapeUtil::Equal(
+ ShapeUtil::GetTupleElementShape(result->shape(), 0),
+ ShapeUtil::MakeShapeWithLayout(S32, /*dimensions=*/{2, 2},
+ /*minor_to_major=*/{0, 1})));
+ EXPECT_TRUE(ShapeUtil::Equal(
+ ShapeUtil::GetTupleElementShape(result->shape(), 1),
+ ShapeUtil::MakeShapeWithLayout(S32, /*dimensions=*/{2, 2},
+ /*minor_to_major=*/{1, 0})));
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/codegen_test_base.cc b/tensorflow/compiler/xla/tests/codegen_test_base.cc
new file mode 100644
index 0000000000..fe4dff2109
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/codegen_test_base.cc
@@ -0,0 +1,90 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/tests/codegen_test_base.h"
+
+#include <stdlib.h>
+#include <utility>
+
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/service/backend.h"
+#include "tensorflow/compiler/xla/service/compiler.h"
+#include "tensorflow/compiler/xla/service/hlo_module_config.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/subprocess.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+
+void CodegenTestBase::CompileAndVerifyIr(std::unique_ptr<HloModule> hlo_module,
+ const string& pattern) {
+ std::unique_ptr<Executable> executable =
+ CompileToExecutable(std::move(hlo_module));
+ string ir_module_string = GetIrFromExecutable(*executable);
+ RunFileCheck(ir_module_string, pattern);
+}
+
+std::unique_ptr<Executable> CodegenTestBase::CompileToExecutable(
+ std::unique_ptr<HloModule> hlo_module) {
+ auto module_config = MakeUnique<HloModuleConfig>(
+ MakeProgramShape(hlo_module->entry_computation()));
+ return backend_->compiler()
+ ->Compile(std::move(hlo_module), std::move(module_config),
+ test_hlo_dumper_, backend_->default_stream_executor())
+ .ConsumeValueOrDie();
+}
+
+void CodegenTestBase::RunFileCheck(const string& input, const string& pattern) {
+  // Write the pattern to a temporary file; the input is piped to FileCheck.
+ char tempdir_template[] = "/tmp/ir_testXXXXXX";
+ char* tempdir_name = mkdtemp(tempdir_template);
+ CHECK_NOTNULL(tempdir_name);
+ string pattern_path =
+ tensorflow::io::JoinPath(tempdir_name, "xla_hlo_test_ir_pattern");
+ TF_CHECK_OK(tensorflow::WriteStringToFile(tensorflow::Env::Default(),
+ pattern_path, pattern));
+
+ // Invoke FileCheck to check whether input matches `pattern`.
+ tensorflow::SubProcess file_check_process;
+ const char* test_srcdir = getenv("TEST_SRCDIR");
+ if (test_srcdir == nullptr) {
+ test_srcdir = ".";
+ }
+ string file_check_path = tensorflow::io::JoinPath(
+ test_srcdir, "external/llvm/FileCheck");
+ file_check_process.SetProgram(file_check_path,
+ {file_check_path, pattern_path});
+ file_check_process.SetChannelAction(tensorflow::CHAN_STDIN,
+ tensorflow::ACTION_PIPE);
+ file_check_process.SetChannelAction(tensorflow::CHAN_STDERR,
+ tensorflow::ACTION_PIPE);
+ CHECK(file_check_process.Start());
+ string standard_error;
+ int exit_status = file_check_process.Communicate(
+ /*stdin_input=*/&input, /*stdout_output=*/nullptr,
+ /*stderr_output=*/&standard_error);
+
+  // FileCheck returns 0 when the inputs match. If matching fails, we output
+  // the error message generated by FileCheck.
+ SCOPED_TRACE(tensorflow::strings::StrCat("Input to FileCheck:\n", input));
+ EXPECT_EQ(0, exit_status) << standard_error;
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/codegen_test_base.h b/tensorflow/compiler/xla/tests/codegen_test_base.h
new file mode 100644
index 0000000000..50c0453107
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/codegen_test_base.h
@@ -0,0 +1,56 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_TESTS_CODEGEN_TEST_BASE_H_
+#define TENSORFLOW_COMPILER_XLA_TESTS_CODEGEN_TEST_BASE_H_
+
+#include <memory>
+#include <string>
+
+#include "tensorflow/compiler/xla/service/executable.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+
+namespace xla {
+
+// Tests that verify IR emitted by the CPU/GPU backend is as expected.
+class CodegenTestBase : public HloTestBase {
+ protected:
+ CodegenTestBase() {}
+
+ // Returns the embedded LLVM IR from the given executable. Codegen tests must
+ // override this method, but execution tests do not have to because they do
+ // not examine the embedded IR.
+ virtual string GetIrFromExecutable(const Executable& executable) = 0;
+
+ // Compiles the given HLO module to LLVM IR and verifies the IR matches the
+ // given pattern. `pattern` is in the FileCheck pattern matching syntax
+ // (http://llvm.org/docs/CommandGuide/FileCheck.html).
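+  // For example, a pattern of "CHECK: fmul float" requires the emitted IR to
+  // contain a floating-point multiply instruction.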
+ void CompileAndVerifyIr(std::unique_ptr<HloModule> hlo_module,
+ const string& pattern);
+
+ protected:
+ // Compiles hlo_module to an executable, CHECK-failing if this fails.
+ std::unique_ptr<Executable> CompileToExecutable(
+ std::unique_ptr<HloModule> hlo_module);
+
+ // Runs FileCheck with the given pattern over the given string and EXPECTs
+ // that FileCheck succeeded in matching the input.
+ void RunFileCheck(const string& input, const string& pattern);
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_TESTS_CODEGEN_TEST_BASE_H_
diff --git a/tensorflow/compiler/xla/tests/compilation_cache_test.cc b/tensorflow/compiler/xla/tests/compilation_cache_test.cc
new file mode 100644
index 0000000000..38ce007cb0
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/compilation_cache_test.cc
@@ -0,0 +1,218 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <initializer_list>
+#include <memory>
+#include <string>
+
+#include "tensorflow/compiler/xla/client/computation.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/global_data.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/tests/test_utils.h"
+#include "tensorflow/compiler/xla/xla.pb.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace {
+
+class CompilationCacheTest : public ClientLibraryTestBase {
+ public:
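+  // Runs the computation with the given arguments, checks the result against
+  // expected_result, and verifies via the ExecutionProfile that the
+  // compilation cache was hit or missed as expected.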
+ void ExecuteComputationR0F32(
+ const Computation& computation,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments, float expected_result,
+ bool expect_cache_hit) {
+ ExecutionProfile execution_profile;
+ std::unique_ptr<Literal> result =
+ client_
+ ->ExecuteAndTransfer(computation, arguments,
+ /*output_layout=*/nullptr, &execution_profile)
+ .ConsumeValueOrDie();
+ LiteralTestUtil::ExpectNear(*LiteralUtil::CreateR0<float>(expected_result),
+ *result, error_spec_);
+ EXPECT_EQ(expect_cache_hit, execution_profile.compilation_cache_hit());
+ }
+
+ void ExecuteComputationR2F32(
+ const Computation& computation,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments,
+ std::initializer_list<std::initializer_list<float>> expected_result,
+ bool expect_cache_hit) {
+ ExecutionProfile execution_profile;
+ auto data_handle =
+ client_
+ ->Execute(computation, arguments, /*output_layout=*/nullptr,
+ &execution_profile)
+ .ConsumeValueOrDie();
+ std::unique_ptr<Literal> result =
+ client_->Transfer(*data_handle).ConsumeValueOrDie();
+ LiteralTestUtil::ExpectNear(*LiteralUtil::CreateR2<float>(expected_result),
+ *result, error_spec_);
+ EXPECT_EQ(expect_cache_hit, execution_profile.compilation_cache_hit());
+ }
+
+ ErrorSpec error_spec_{0.0001};
+};
+
+XLA_TEST_F(CompilationCacheTest, ComputationCalledMultipleTimes) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Neg(builder.ConstantR0<float>(42.0));
+ Computation computation = builder.Build().ConsumeValueOrDie();
+
+ ExecuteComputationR0F32(computation, {}, -42.0, /*expect_cache_hit=*/false);
+ ExecuteComputationR0F32(computation, {}, -42.0, /*expect_cache_hit=*/true);
+ ExecuteComputationR0F32(computation, {}, -42.0, /*expect_cache_hit=*/true);
+}
+
+XLA_TEST_F(CompilationCacheTest, ComputationCalledWithDifferentParameters) {
+ std::unique_ptr<GlobalData> data_42 =
+ client_->TransferToServer(*LiteralUtil::CreateR0<float>(42.0f))
+ .ConsumeValueOrDie();
+ std::unique_ptr<GlobalData> data_123 =
+ client_->TransferToServer(*LiteralUtil::CreateR0<float>(123.0f))
+ .ConsumeValueOrDie();
+ std::unique_ptr<GlobalData> data_456 =
+ client_->TransferToServer(*LiteralUtil::CreateR0<float>(456.0f))
+ .ConsumeValueOrDie();
+
+ ComputationBuilder builder(client_, TestName());
+ builder.Neg(builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param"));
+ Computation computation = builder.Build().ConsumeValueOrDie();
+
+ ExecuteComputationR0F32(computation, {data_42.get()}, -42.0,
+ /*expect_cache_hit=*/false);
+ ExecuteComputationR0F32(computation, {data_123.get()}, -123.0,
+ /*expect_cache_hit=*/true);
+ ExecuteComputationR0F32(computation, {data_456.get()}, -456.0,
+ /*expect_cache_hit=*/true);
+ ExecuteComputationR0F32(computation, {data_42.get()}, -42.0,
+ /*expect_cache_hit=*/true);
+}
+
+XLA_TEST_F(CompilationCacheTest, MultipleComputations) {
+ ComputationBuilder builder_neg(client_, TestName() + "_neg");
+ builder_neg.Neg(builder_neg.ConstantR0<float>(42.0));
+ Computation computation_neg = builder_neg.Build().ConsumeValueOrDie();
+
+ ComputationBuilder builder_exp(client_, TestName() + "_exp");
+ builder_exp.Exp(builder_exp.ConstantR0<float>(1.0));
+ Computation computation_exp = builder_exp.Build().ConsumeValueOrDie();
+
+ ComputationBuilder builder_add(client_, TestName() + "_add");
+ builder_add.Add(builder_add.ConstantR0<float>(2.0),
+ builder_add.ConstantR0<float>(3.0));
+ Computation computation_add = builder_add.Build().ConsumeValueOrDie();
+
+ ExecuteComputationR0F32(computation_neg, {}, -42.0,
+ /*expect_cache_hit=*/false);
+ ExecuteComputationR0F32(computation_exp, {}, 2.7182817,
+ /*expect_cache_hit=*/false);
+ ExecuteComputationR0F32(computation_add, {}, 5.0,
+ /*expect_cache_hit=*/false);
+ ExecuteComputationR0F32(computation_neg, {}, -42.0,
+ /*expect_cache_hit=*/true);
+}
+
+XLA_TEST_F(CompilationCacheTest, DifferentParameterLayouts) {
+ // Create two GlobalData arrays with the same shape but different
+ // layouts. Use these arrays as parameters to a simple computation. If the
+  // layout of the array changes, then the computation should be recompiled
+  // (cache miss).
+ auto rowmaj_array = test_utils::CreateR2LiteralWithLayout(
+ {{1.0f, 2.0f}, {3.0f, 4.0f}}, /*minor_to_major=*/{1, 0});
+ auto rowmaj_handle =
+ client_->TransferToServer(*rowmaj_array).ConsumeValueOrDie();
+
+ auto colmaj_array = test_utils::CreateR2LiteralWithLayout(
+ {{1.0f, 2.0f}, {3.0f, 4.0f}}, /*minor_to_major=*/{0, 1});
+ auto colmaj_handle =
+ client_->TransferToServer(*colmaj_array).ConsumeValueOrDie();
+
+ ComputationBuilder builder(client_, TestName());
+ builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2}), "param0");
+ Computation computation = builder.Build().ConsumeValueOrDie();
+
+ ExecuteComputationR2F32(computation, {colmaj_handle.get()},
+ {{1.0f, 2.0f}, {3.0f, 4.0f}},
+ /*expect_cache_hit=*/false);
+ ExecuteComputationR2F32(computation, {colmaj_handle.get()},
+ {{1.0f, 2.0f}, {3.0f, 4.0f}},
+ /*expect_cache_hit=*/true);
+ ExecuteComputationR2F32(computation, {rowmaj_handle.get()},
+ {{1.0f, 2.0f}, {3.0f, 4.0f}},
+ /*expect_cache_hit=*/false);
+ ExecuteComputationR2F32(computation, {rowmaj_handle.get()},
+ {{1.0f, 2.0f}, {3.0f, 4.0f}},
+ /*expect_cache_hit=*/true);
+ ExecuteComputationR2F32(computation, {colmaj_handle.get()},
+ {{1.0f, 2.0f}, {3.0f, 4.0f}},
+ /*expect_cache_hit=*/true);
+}
+
+XLA_TEST_F(CompilationCacheTest, MutatedComputation) {
+ // Build a computation, execute it, then mutate it. The mutated computation
+ // should not be in the cache until it is run once. This must be done through
+ // the stub interface because Computations built from ComputationBuilder are
+ // immutable.
+ ComputationBuilder builder(client_, TestName());
+ auto neg = builder.Neg(builder.ConstantR0<float>(42.0));
+ Computation computation = builder.Build().ConsumeValueOrDie();
+
+ ExecuteComputationR0F32(computation, {}, -42.0, /*expect_cache_hit=*/false);
+ ExecuteComputationR0F32(computation, {}, -42.0, /*expect_cache_hit=*/true);
+
+ BinaryOpRequest request;
+ request.set_binop(BINOP_ADD);
+ *request.mutable_lhs() = neg;
+ *request.mutable_rhs() = neg;
+ OpRequest op_request;
+ *op_request.mutable_computation() = computation.handle();
+ *op_request.mutable_binary_op_request() = request;
+ OpResponse response;
+ tensorflow::Status s = client_->stub()->Op(&op_request, &response);
+ ASSERT_TRUE(s.ok());
+
+ ExecuteComputationR0F32(computation, {}, -84.0, /*expect_cache_hit=*/false);
+ ExecuteComputationR0F32(computation, {}, -84.0, /*expect_cache_hit=*/true);
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/compute_constant_test.cc b/tensorflow/compiler/xla/tests/compute_constant_test.cc
new file mode 100644
index 0000000000..709ce5029c
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/compute_constant_test.cc
@@ -0,0 +1,249 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/compiler/xla/client/computation.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/global_data.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/tests/test_utils.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+class ComputeConstantTest : public ClientLibraryTestBase {
+ public:
+ StatusOr<std::unique_ptr<Literal>> ComputeConstantLiteral(
+ ComputationDataHandle operand, ComputationBuilder* builder,
+ Layout* output_layout = nullptr) {
+ TF_ASSIGN_OR_RETURN(auto remote_computed,
+ builder->ComputeConstant(operand, output_layout));
+ TF_ASSIGN_OR_RETURN(auto computed, client_->Transfer(*remote_computed));
+ return std::move(computed);
+ }
+
+ template <class Scalar>
+ StatusOr<Scalar> ComputeConstantScalar(ComputationDataHandle operand,
+ ComputationBuilder* builder) {
+ TF_ASSIGN_OR_RETURN(auto literal, ComputeConstantLiteral(operand, builder));
+ return LiteralUtil::Get<Scalar>(*literal, {});
+ }
+
+ bool IsConstant(const ComputationDataHandle& operand,
+ ComputationBuilder* builder) {
+ StatusOr<bool> result = builder->IsConstant(operand);
+ EXPECT_TRUE(result.ok()) << result.status();
+ return result.ok() ? result.ValueOrDie() : false;
+ }
+
+ template <class Scalar>
+ void ExpectConstantComputedScalar(ComputationDataHandle operand,
+ Scalar expected,
+ ComputationBuilder* builder) {
+    // ComputeConstantScalar returns a StatusOr<Scalar>, so check the status
+    // before unwrapping the value.
+    StatusOr<Scalar> computed = ComputeConstantScalar<Scalar>(operand, builder);
+    ASSERT_TRUE(computed.ok()) << computed.status();
+    std::unique_ptr<Literal> expected_literal = LiteralUtil::CreateR0(expected);
+    std::unique_ptr<Literal> computed_literal =
+        LiteralUtil::CreateR0(computed.ValueOrDie());
+    LiteralTestUtil::ExpectEqual(*expected_literal, *computed_literal);
+ }
+};
+
+TEST_F(ComputeConstantTest, ScalarInt32Literal) {
+ ComputationBuilder b(client_, TestName());
+ auto computation = b.ConstantR0<int32>(42);
+ EXPECT_TRUE(IsConstant(computation, &b));
+
+ auto value = ComputeConstantScalar<int32>(computation, &b);
+ ASSERT_TRUE(value.ok()) << value.status();
+ EXPECT_EQ(value.ValueOrDie(), 42);
+}
+
+TEST_F(ComputeConstantTest, ScalarFloatAdd) {
+ ComputationBuilder b(client_, TestName());
+ auto computation =
+ b.Add(b.ConstantR0<float>(42.5f), b.ConstantR0<float>(1.5f));
+ EXPECT_TRUE(IsConstant(computation, &b));
+
+ auto value = ComputeConstantScalar<float>(computation, &b);
+ ASSERT_TRUE(value.ok()) << value.status();
+ EXPECT_EQ(value.ValueOrDie(), 44.0f);
+}
+
+TEST_F(ComputeConstantTest, ScalarRng) {
+ ComputationBuilder b(client_, TestName());
+ auto computation =
+ b.RngUniform(b.ConstantR0<float>(1.1f), b.ConstantR0<float>(2.1f),
+ ShapeUtil::MakeShape(F32, {}));
+ EXPECT_FALSE(IsConstant(computation, &b));
+
+ auto value = ComputeConstantScalar<float>(computation, &b);
+  ASSERT_FALSE(value.ok())
+      << "computing an RNG value should not be considered a constant";
+}
+
+TEST_F(ComputeConstantTest, DirectParam) {
+ ComputationBuilder b(client_, TestName());
+ auto computation = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param");
+ EXPECT_FALSE(IsConstant(computation, &b));
+
+ auto value = ComputeConstantScalar<float>(computation, &b);
+ EXPECT_TRUE(tensorflow::StringPiece(value.status().ToString())
+ .contains("depends on parameter"))
+ << value.status();
+}
+
+TEST_F(ComputeConstantTest, IndirectParam) {
+ ComputationBuilder b(client_, TestName());
+ auto computation =
+ b.Add(b.ConstantR0<float>(1.0f),
+ b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param"));
+ EXPECT_FALSE(IsConstant(computation, &b));
+
+ auto value = ComputeConstantScalar<float>(computation, &b);
+ EXPECT_TRUE(tensorflow::StringPiece(value.status().ToString())
+ .contains("depends on parameter"))
+ << value.status();
+}
+
+// Tests computing an expression that is interspersed with param nodes but
+// does not itself depend on those param nodes.
+TEST_F(ComputeConstantTest, UnrelatedParam) {
+ ComputationBuilder b(client_, TestName());
+
+ auto param_a = b.Parameter(10, ShapeUtil::MakeShape(F32, {}), "param0");
+ auto constant_4 = b.Add(b.ConstantR0<float>(2.5f), b.ConstantR0<float>(1.5f));
+ auto not_constant_a = b.Add(constant_4, param_a);
+
+ auto param_b = b.Parameter(1, ShapeUtil::MakeShape(F32, {}), "param1");
+ auto constant_9 = b.Mul(b.ConstantR0<float>(2.0f), b.ConstantR0<float>(4.5f));
+ auto not_constant_b = b.Add(param_b, constant_9);
+
+ auto constant_13 = b.Add(constant_4, constant_9);
+ b.Add(not_constant_b, b.Add(constant_13, not_constant_a));
+
+ EXPECT_TRUE(IsConstant(constant_13, &b));
+
+ auto value = ComputeConstantScalar<float>(constant_13, &b);
+ ASSERT_TRUE(value.ok()) << value.status();
+ EXPECT_EQ(value.ValueOrDie(), 13.0f);
+}
+
+TEST_F(ComputeConstantTest, NonScalarAdd) {
+ ComputationBuilder b(client_, TestName());
+
+ auto computation =
+ b.Add(b.ConstantR1<int32>({1, 2}), b.ConstantR1<int32>({3, 4}));
+ EXPECT_TRUE(IsConstant(computation, &b));
+
+ auto computed = ComputeConstantLiteral(computation, &b);
+ ASSERT_TRUE(computed.ok()) << computed.status();
+ std::unique_ptr<Literal> expected_literal =
+ LiteralUtil::CreateR1<int32>({4, 6});
+ LiteralTestUtil::ExpectEqual(*expected_literal, *computed.ValueOrDie());
+}
+
+TEST_F(ComputeConstantTest, IntegerDivide) {
+ ComputationBuilder b(client_, TestName());
+ auto computation = b.Div(b.ConstantR0<int32>(15), b.ConstantR0<int32>(3));
+ EXPECT_TRUE(IsConstant(computation, &b));
+
+ auto computed = ComputeConstantLiteral(computation, &b);
+ ASSERT_TRUE(computed.ok()) << computed.status();
+ std::unique_ptr<Literal> expected_literal = LiteralUtil::CreateR0<int32>(5);
+ LiteralTestUtil::ExpectEqual(*expected_literal, *computed.ValueOrDie());
+}
+
+XLA_TEST_F(ComputeConstantTest, Layout) {
+ ComputationBuilder b(client_, TestName());
+
+ std::vector<std::vector<int64>> layouts = {{0, 1}, {1, 0}};
+ for (const std::vector<int64>& layout : layouts) {
+ auto layout_proto = LayoutUtil::MakeLayout(layout);
+ auto computed =
+ ComputeConstantLiteral(b.Add(b.ConstantR2<int32>({{1, 2}, {3, 4}}),
+ b.ConstantR2<int32>({{10, 20}, {30, 40}})),
+ &b, &layout_proto);
+ ASSERT_TRUE(computed.ok()) << computed.status();
+
+ std::unique_ptr<Literal> expected_literal =
+ test_utils::CreateR2LiteralWithLayout<int32>({{11, 22}, {33, 44}},
+ layout);
+ LiteralTestUtil::AssertEqualShapesAndLayouts(
+ expected_literal->shape(), computed.ValueOrDie()->shape());
+ LiteralTestUtil::ExpectEqual(*expected_literal, *computed.ValueOrDie());
+ }
+}
+
+// This test is permanently disabled on CPU because it requires that the
+// backend used for execution be different from the backend used for
+// ComputeConstant, which is always the CPU.
+TEST_F(ComputeConstantTest, DISABLED_ON_CPU(ReuseComputedConstant)) {
+ // Compute a trivial constant, then try to use the value in an Execute
+ // call. This should fail because the constant resides on the CPU and the
+ // Execute call is executed on a different backend.
+ ComputationBuilder constant_b(client_, TestName());
+ auto constant = constant_b.ConstantR0<int32>(42);
+ auto handle = constant_b.ComputeConstant(constant).ConsumeValueOrDie();
+ auto literal = client_->Transfer(*handle).ConsumeValueOrDie();
+ LiteralTestUtil::ExpectR0Equal(42, *literal);
+
+ // Build trivial computation which takes one parameter.
+ ComputationBuilder b(client_, TestName());
+ b.Neg(b.Parameter(0, ShapeUtil::MakeShape(S32, {}), "param0"));
+ auto computation = b.Build().ConsumeValueOrDie();
+
+ // Try to use value from ComputeConstant in Execute.
+ auto execute_status = client_->Execute(computation, {handle.get()});
+ EXPECT_FALSE(execute_status.ok());
+ EXPECT_MATCH(
+ execute_status.status().error_message(),
+ testing::ContainsRegex("argument 0 is on device Host:0 but computation "
+ "will be executed on device"));
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/concat_test.cc b/tensorflow/compiler/xla/tests/concat_test.cc
new file mode 100644
index 0000000000..9a48b19b96
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/concat_test.cc
@@ -0,0 +1,523 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+#include <numeric>
+#include <vector>
+
+#include "tensorflow/compiler/xla/array2d.h"
+#include "tensorflow/compiler/xla/array3d.h"
+#include "tensorflow/compiler/xla/client/computation.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/reference_util.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace {
+
+using ConcatTest = ClientLibraryTestBase;
+
+// Concatenate expects at least one argument.
+XLA_TEST_F(ConcatTest, Concat_Nothing) {
+ ComputationBuilder builder(client_, TestName());
+ auto concatenated = builder.ConcatInDim({}, 0);
+ StatusOr<Computation> computation_status = builder.Build();
+ ASSERT_FALSE(computation_status.ok());
+ EXPECT_MATCH(
+ computation_status.status().ToString(),
+ testing::ContainsRegex("Concatenate expects at least one argument"));
+}
+
+// Concatenate with one argument works.
+XLA_TEST_F(ConcatTest, Concat_R1_With_Nothing) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<float>({42.0, 64.0});
+ auto concatenated = builder.ConcatInDim({a}, 0);
+
+ std::vector<float> expected = {42, 64};
+ ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+// Show that we can't concatenate R0 with R0 because we can't name the dimension
+// to concatenate on.
+XLA_TEST_F(ConcatTest, CannotConcatR0WithR0) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR0<float>(42.0);
+ auto b = builder.ConstantR0<float>(64.0);
+ auto concatenated = builder.ConcatInDim({a, b}, 0);
+ StatusOr<Computation> computation_status = builder.Build();
+ ASSERT_FALSE(computation_status.ok());
+ EXPECT_MATCH(computation_status.status().ToString(),
+ testing::ContainsRegex(
+ "dimension to concatenate along out of bounds: 0"));
+}
+
+XLA_TEST_F(ConcatTest, Concat_R1_L0_With_R1_L0) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<float>({});
+ auto b = builder.ConstantR1<float>({});
+ auto concatenated = builder.ConcatInDim({a, b}, 0);
+
+ std::vector<float> expected = {};
+ ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(ConcatTest, Concat_R1_L0_With_R1_L1) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<float>({});
+ auto b = builder.ConstantR1<float>({256.0});
+ auto concatenated = builder.ConcatInDim({a, b}, 0);
+
+ std::vector<float> expected = {256};
+ ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(ConcatTest, Concat_R1_L2_With_R1_L0) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<float>({42.0, 64.0});
+ auto b = builder.ConstantR1<float>({});
+ auto concatenated = builder.ConcatInDim({a, b}, 0);
+
+ std::vector<float> expected = {42, 64};
+ ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(ConcatTest, Concat_R1_L2_With_R1_L1) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<float>({42.0, 64.0});
+ auto b = builder.ConstantR1<float>({256.0});
+ auto concatenated = builder.ConcatInDim({a, b}, 0);
+
+ std::vector<float> expected = {42, 64, 256};
+ ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(ConcatTest, Concat_R1_L253_With_R1_L7) {
+ std::vector<float> lhs(253);
+ std::vector<float> rhs(7);
+ std::vector<float> expected(253 + 7);
+ for (int i = 0; i < 253; ++i) {
+ expected[i] = lhs[i] = i + 1;
+ }
+ for (int i = 0; i < 7; ++i) {
+ expected[253 + i] = rhs[i] = 253 + i + 1;
+ }
+
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<float>(lhs);
+ auto b = builder.ConstantR1<float>(rhs);
+ auto concatenated = builder.ConcatInDim({a, b}, 0);
+
+ ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(ConcatTest, Concat_0x0_With_0x0) {
+ for (int dim : {0, 1}) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR2FromArray2D(Array2D<float>(0, 0));
+ auto b = builder.ConstantR2FromArray2D(Array2D<float>(0, 0));
+ auto concatenated = builder.ConcatInDim({a, b}, dim);
+
+ ComputeAndCompareR2<float>(&builder, Array2D<float>(0, 0), {},
+ ErrorSpec(0.0001));
+ }
+}
+
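+// The patterned matrices used below come from ClientLibraryTestBase's
+// CreatePatternedMatrix helper; as the expected values in these tests
+// reflect, entry (row, col) holds row * 1000 + col + offset. E.g. a 2x3
+// matrix with the default offset of 0 is:
+//
+//   {{   0,    1,    2},
+//    {1000, 1001, 1002}}
+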
+XLA_TEST_F(ConcatTest, Concat_1x1_With_1x1_InDim0) {
+ ComputationBuilder builder(client_, TestName());
+ auto a_array = CreatePatternedMatrix(1, 1);
+ auto b_array = CreatePatternedMatrix(1, 1, /*offset=*/64.0);
+ auto a = builder.ConstantR2FromArray2D(*a_array);
+ auto b = builder.ConstantR2FromArray2D(*b_array);
+ auto concatenated = builder.ConcatInDim({a, b}, 0);
+
+ Array2D<float> expected({
+ {0}, {64},
+ });
+ ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(ConcatTest, Concat_1x1_With_1x1_InDim1) {
+ ComputationBuilder builder(client_, TestName());
+ auto a_array = CreatePatternedMatrix(1, 1);
+ auto b_array = CreatePatternedMatrix(1, 1, /*offset=*/64.0);
+ auto a = builder.ConstantR2FromArray2D(*a_array);
+ auto b = builder.ConstantR2FromArray2D(*b_array);
+ auto concatenated = builder.ConcatInDim({a, b}, 1);
+
+ Array2D<float> expected({
+ {0, 64},
+ });
+ ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(ConcatTest, Concat2x0With2x5) {
+ ComputationBuilder builder(client_, TestName());
+ auto b_array = CreatePatternedMatrix(2, 5, /*offset=*/64.0);
+ auto a = builder.ConstantR2FromArray2D(Array2D<float>(2, 0));
+ auto b = builder.ConstantR2FromArray2D(*b_array);
+ auto concatenated = builder.ConcatInDim({a, b}, 1);
+
+ ComputeAndCompareR2<float>(&builder, *b_array, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(ConcatTest, Concat2x3With2x5) {
+ ComputationBuilder builder(client_, TestName());
+ auto a_array = CreatePatternedMatrix(2, 3);
+ auto b_array = CreatePatternedMatrix(2, 5, /*offset=*/64.0);
+ auto a = builder.ConstantR2FromArray2D(*a_array);
+ auto b = builder.ConstantR2FromArray2D(*b_array);
+ auto concatenated = builder.ConcatInDim({a, b}, 1);
+
+ Array2D<float> expected({
+ {0, 1, 2, 64, 65, 66, 67, 68},
+ {1000, 1001, 1002, 1064, 1065, 1066, 1067, 1068},
+ });
+ ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(ConcatTest, Concat3x2With0x2) {
+ ComputationBuilder builder(client_, TestName());
+ auto a_array = CreatePatternedMatrix(3, 2);
+ auto a = builder.ConstantR2FromArray2D(*a_array);
+ auto b = builder.ConstantR2FromArray2D(Array2D<float>(0, 2));
+ auto concatenated = builder.ConcatInDim({a, b}, 0);
+
+ ComputeAndCompareR2<float>(&builder, *a_array, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(ConcatTest, Concat3x2With5x2) {
+ ComputationBuilder builder(client_, TestName());
+ auto a_array = CreatePatternedMatrix(3, 2);
+ auto b_array = CreatePatternedMatrix(5, 2, /*offset=*/64.0);
+ auto a = builder.ConstantR2FromArray2D(*a_array);
+ auto b = builder.ConstantR2FromArray2D(*b_array);
+ auto concatenated = builder.ConcatInDim({a, b}, 0);
+
+ Array2D<float> expected({
+ {0, 1},
+ {1000, 1001},
+ {2000, 2001},
+ {64, 65},
+ {1064, 1065},
+ {2064, 2065},
+ {3064, 3065},
+ {4064, 4065},
+ });
+ ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(ConcatTest, Concat_R3_3x0x2_3x0x1) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR3FromArray3D(Array3D<float>(3, 0, 2));
+ auto b = builder.ConstantR3FromArray3D(Array3D<float>(3, 0, 1));
+ auto concatenated = builder.ConcatInDim({a, b}, 2);
+ ComputeAndCompareR3<float>(&builder, Array3D<float>(3, 0, 3), {},
+ ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(ConcatTest, Concat_R3_3x1x2_3x1x1) {
+ ComputationBuilder builder(client_, TestName());
+ Array3D<float> a_array({
+ // 3x1x2
+ {{0, 1}},
+ {{2, 3}},
+ {{4, 5}},
+ });
+ Array3D<float> b_array({
+ // 3x1x1
+ {{6}},
+ {{7}},
+ {{8}},
+ });
+ auto a = builder.ConstantR3FromArray3D(a_array);
+ auto b = builder.ConstantR3FromArray3D(b_array);
+ auto concatenated = builder.ConcatInDim({a, b}, 2);
+
+ Array3D<float> expected({
+ {{0, 1, 6}}, {{2, 3, 7}}, {{4, 5, 8}},
+ });
+ ComputeAndCompareR3<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(ConcatTest, Concat_R1_1x1_1x1_1x1) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<float>({42.0});
+ auto b = builder.ConstantR1<float>({64.0});
+ auto c = builder.ConstantR1<float>({256.0});
+ auto concatenated = builder.ConcatInDim({a, b, c}, 0);
+
+ std::vector<float> expected = {42, 64, 256};
+ ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(ConcatTest, Concat_R3_3x1x2_3x1x1_3x1x1) {
+ ComputationBuilder builder(client_, TestName());
+ Array3D<float> a_array({
+ // 3x1x2
+ {{0, 1}},
+ {{4, 5}},
+ {{8, 9}},
+ });
+ Array3D<float> b_array({
+ // 3x1x1
+ {{2}},
+ {{6}},
+ {{10}},
+ });
+ Array3D<float> c_array({
+ // 3x1x1
+ {{3}},
+ {{7}},
+ {{11}},
+ });
+ auto a = builder.ConstantR3FromArray3D(a_array);
+ auto b = builder.ConstantR3FromArray3D(b_array);
+ auto c = builder.ConstantR3FromArray3D(c_array);
+ auto concatenated = builder.ConcatInDim({a, b, c}, 2);
+
+ Array3D<float> expected({
+ {{0, 1, 2, 3}}, {{4, 5, 6, 7}}, {{8, 9, 10, 11}},
+ });
+ ComputeAndCompareR3<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(ConcatTest, DoubleConcatLeftAssociative) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<float>({42.0});
+ auto b = builder.ConstantR1<float>({64.0});
+ auto c = builder.ConstantR1<float>({256.0});
+ // concatenated = (a concat b) concat c
+ auto concatenated =
+ builder.ConcatInDim({builder.ConcatInDim({a, b}, 0), c}, 0);
+
+ std::vector<float> expected = {42, 64, 256};
+ ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(ConcatTest, DoubleConcatRightAssociative) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<float>({42.0});
+ auto b = builder.ConstantR1<float>({64.0});
+ auto c = builder.ConstantR1<float>({256.0});
+ // concatenated = a concat (b concat c)
+ auto concatenated =
+ builder.ConcatInDim({a, builder.ConcatInDim({b, c}, 0)}, 0);
+
+ std::vector<float> expected = {42, 64, 256};
+ ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(ConcatTest, Concat_1x1024_With_1x1024_InDim0) {
+ Array2D<float> lhs(1, 1024);
+ Array2D<float> rhs(1, 1024);
+ for (int i = 0; i < 1024; ++i) {
+ lhs(0, i) = i;
+ rhs(0, i) = i + 1024;
+ }
+
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR2FromArray2D<float>(lhs);
+ auto b = builder.ConstantR2FromArray2D<float>(rhs);
+ builder.ConcatInDim({a, b}, 0);
+
+ Array2D<float> expected(2, 1024);
+ for (int i = 0; i < 1024; ++i) {
+ expected(0, i) = i;
+ expected(1, i) = i + 1024;
+ }
+ ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(ConcatTest, Concat_1x1024_With_1x1024_InDim1) {
+ Array2D<float> lhs(1, 1024);
+ Array2D<float> rhs(1, 1024);
+ for (int i = 0; i < 1024; ++i) {
+ lhs(0, i) = i;
+ rhs(0, i) = i + 1024;
+ }
+
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR2FromArray2D<float>(lhs);
+ auto b = builder.ConstantR2FromArray2D<float>(rhs);
+ builder.ConcatInDim({a, b}, 1);
+
+ Array2D<float> expected(1, 2048);
+ for (int i = 0; i < 1024; ++i) {
+ expected(0, i) = i;
+ expected(0, i + 1024) = i + 1024;
+ }
+ ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(ConcatTest, Concat_64x64_With_64x2) {
+ Array2D<float> lhs(64, 64);
+ Array2D<float> rhs(64, 2);
+ for (int i0 = 0; i0 < 64; ++i0) {
+ for (int i1 = 0; i1 < 64; ++i1) {
+ lhs(i0, i1) = (i0 << 10) | i1;
+ }
+ for (int i1 = 0; i1 < 2; ++i1) {
+ rhs(i0, i1) = (i0 << 10) | (i1 + 64);
+ }
+ }
+
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR2FromArray2D<float>(lhs);
+ auto b = builder.ConstantR2FromArray2D<float>(rhs);
+ builder.ConcatInDim({a, b}, 1);
+
+ Array2D<float> expected(64, 66);
+ for (int i0 = 0; i0 < 64; ++i0) {
+ for (int i1 = 0; i1 < 66; ++i1) {
+ expected(i0, i1) = (i0 << 10) | i1;
+ }
+ }
+ ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+// Show that we can't concatenate opaques.
+XLA_TEST_F(ConcatTest, CannotConcatOpaques) {
+ ComputationBuilder builder(client_, TestName());
+ auto opaque_shape = ShapeUtil::MakeOpaqueShape();
+ auto r1f32 = xla::ShapeUtil::MakeShape(xla::F32, {1});
+ auto x = builder.Parameter(0, r1f32, "x");
+ auto y = builder.Parameter(1, opaque_shape, "y");
+ auto concatenated = builder.ConcatInDim({x, y}, 0);
+ StatusOr<Computation> computation_status = builder.Build();
+ ASSERT_FALSE(computation_status.ok());
+ EXPECT_MATCH(
+ computation_status.status().ToString(),
+ testing::ContainsRegex(
+ "Expected non-opaque argument for operand of concatenation"));
+}
+
+XLA_TEST_F(ConcatTest, ConcatSeveralBoxedPredicates) {
+ ComputationBuilder builder(client_, TestName());
+ auto p0 = builder.ConstantR1<bool>({true});
+ auto p1 = builder.ConstantR1<bool>({false});
+ auto p2 = builder.ConstantR1<bool>({true});
+ auto concatenated = builder.ConcatInDim({p0, p1, p2}, 0);
+
+ bool expected[] = {true, false, true};
+ ComputeAndCompareR1<bool>(&builder, expected, {});
+}
+
+XLA_TEST_F(ConcatTest, ConcatSeveralR1S32s) {
+ ComputationBuilder builder(client_, TestName());
+ auto a0 = builder.ConstantR1<int32>({1});
+ auto a1 = builder.ConstantR1<int32>({2, 3});
+ auto a2 = builder.ConstantR1<int32>({4, 5, 6});
+ auto a3 = builder.ConstantR1<int32>({7, 8, 9, 10});
+ auto concatenated = builder.ConcatInDim({a0, a1, a2, a3}, 0);
+
+ std::vector<int32> expected(10);
+ std::iota(expected.begin(), expected.end(), 1);
+ ComputeAndCompareR1<int32>(&builder, expected, {});
+}
+
+// Describes a binary rank-2 concatenation test.
+struct R2BinarySpec {
+ int64 lhs_dim0;
+ int64 lhs_dim1;
+ int64 rhs_dim0;
+ int64 rhs_dim1;
+ int64 concat_dimension;
+};
+
+// TEST_P harness for binary rank-2 concatenation.
+class ConcatR2BinaryTest : public ClientLibraryTestBase,
+ public ::testing::WithParamInterface<R2BinarySpec> {
+};
+
+TEST_P(ConcatR2BinaryTest, DoIt) {
+ const R2BinarySpec& spec = GetParam();
+ Array2D<int32> lhs(spec.lhs_dim0, spec.lhs_dim1);
+ lhs.FillUnique();
+ Array2D<int32> rhs(spec.rhs_dim0, spec.rhs_dim1);
+ rhs.FillUnique(1000);
+
+ ComputationBuilder builder(client_, TestName());
+ auto a0 = builder.ConstantR2FromArray2D<int32>(lhs);
+ auto a1 = builder.ConstantR2FromArray2D<int32>(rhs);
+ builder.ConcatInDim({a0, a1}, spec.concat_dimension);
+
+ std::unique_ptr<Array2D<int32>> expected =
+ ReferenceUtil::Concat2D(lhs, rhs, spec.concat_dimension);
+ ComputeAndCompareR2<int32>(&builder, *expected, {});
+}
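+
+// For reference, the semantics checked against ReferenceUtil::Concat2D (a
+// sketch assuming the natural definition of rank-2 concatenation, not a
+// quote of the real implementation): with concat_dimension == 0,
+//
+//   out(i, j) = i < lhs.height() ? lhs(i, j) : rhs(i - lhs.height(), j);
+//
+// and with concat_dimension == 1,
+//
+//   out(i, j) = j < lhs.width() ? lhs(i, j) : rhs(i, j - lhs.width());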
+
+// Regression test for b/31944287. x*y is used (at the same index) by all
+// operands of the concat. We should emit x*y in three incoming basic blocks of
+// the concat because these basic blocks are not control-equivalent.
+//
+//      x*y
+//    /  |   \
+// add1 add2 add3
+//    \  |   /
+//     concat
+XLA_TEST_F(ConcatTest, ConcatOperandsOfSameOperand) {
+ auto f32_scalar = ShapeUtil::MakeShape(xla::F32, {});
+ auto x_literal = LiteralUtil::CreateR0<float>(2.f);
+ auto y_literal = LiteralUtil::CreateR0<float>(3.f);
+ auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie();
+ auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie();
+
+ ComputationBuilder builder(client_, TestName());
+ auto x = builder.Parameter(0, f32_scalar, "x");
+ auto y = builder.Parameter(1, f32_scalar, "y");
+ auto mul = builder.Mul(x, y);
+ auto add1 = builder.Add(mul, builder.ConstantR1<float>({1.f, 2.f}));
+ auto add2 = builder.Add(mul, builder.ConstantR1<float>({3.f, 4.f}));
+ auto add3 = builder.Add(mul, builder.ConstantR1<float>({5.f, 6.f}));
+ builder.ConcatInDim({add1, add2, add3}, /*dimension=*/0);
+
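+ // x*y = 2*3 = 6, so the adds evaluate to {7, 8}, {9, 10}, and {11, 12},
+ // which concatenate to the expected {7, 8, 9, 10, 11, 12}.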
+ ComputeAndCompareR1<float>(&builder, {7., 8., 9., 10., 11., 12.},
+ {x_data.get(), y_data.get()}, ErrorSpec(1e-4));
+}
+
+INSTANTIATE_TEST_CASE_P(ConcatR2BinaryTestInstantiation, ConcatR2BinaryTest,
+ ::testing::Values(R2BinarySpec{1, 1, 1, 1, 0},
+ R2BinarySpec{1, 1, 1, 1, 1},
+ R2BinarySpec{4, 3, 4, 3, 0},
+ R2BinarySpec{4, 3, 4, 3, 1},
+ R2BinarySpec{7, 128, 1, 128, 0},
+ R2BinarySpec{8, 127, 8, 1, 1}));
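+
+// The specs above cover the degenerate 1x1 cases in both dimensions,
+// equal-shaped operands concatenated along each dimension, and mixed-size
+// operands where one side contributes a single row or column.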
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/constants_test.cc b/tensorflow/compiler/xla/tests/constants_test.cc
new file mode 100644
index 0000000000..58d52ac116
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/constants_test.cc
@@ -0,0 +1,193 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Tests that constants in program memory round-trip as expected.
+
+#include <memory>
+#include <vector>
+
+#include "tensorflow/compiler/xla/array2d.h"
+#include "tensorflow/compiler/xla/array3d.h"
+#include "tensorflow/compiler/xla/array4d.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+class ConstantsTest : public ClientLibraryTestBase {
+ protected:
+ const ErrorSpec error_spec_{1e-3, 1e-5};
+};
+
+TEST_F(ConstantsTest, ZeroCellF32) {
+ ComputationBuilder builder(client_, TestName());
+ builder.ConstantR1<float>({});
+
+ ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
+}
+
+TEST_F(ConstantsTest, OneCellF32) {
+ std::vector<float> constant = {2.0};
+
+ ComputationBuilder builder(client_, TestName());
+ builder.ConstantR1<float>(constant);
+
+ ComputeAndCompareR1<float>(&builder, constant, {}, error_spec_);
+}
+
+TEST_F(ConstantsTest, OneCellS32) {
+ std::vector<int32> constant = {2};
+
+ ComputationBuilder builder(client_, TestName());
+ builder.ConstantR1<int32>(constant);
+
+ ComputeAndCompareR1<int32>(&builder, constant, {});
+}
+
+TEST_F(ConstantsTest, OneCellU32) {
+ std::vector<uint32> constant = {2};
+
+ ComputationBuilder builder(client_, TestName());
+ builder.ConstantR1<uint32>(constant);
+
+ ComputeAndCompareR1<uint32>(&builder, constant, {});
+}
+
+TEST_F(ConstantsTest, EightCells) {
+ std::vector<float> constant = {0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0};
+
+ ComputationBuilder builder(client_, TestName());
+ builder.ConstantR1<float>(constant);
+
+ ComputeAndCompareR1<float>(&builder, constant, {}, error_spec_);
+}
+
+TEST_F(ConstantsTest, SixteenCells) {
+ std::vector<float> constant = {0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0,
+ 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0};
+
+ ComputationBuilder builder(client_, TestName());
+ builder.ConstantR1<float>(constant);
+
+ ComputeAndCompareR1<float>(&builder, constant, {}, error_spec_);
+}
+
+TEST_F(ConstantsTest, Empty_0x2) {
+ ComputationBuilder builder(client_, TestName());
+ builder.ConstantR2FromArray2D<float>(Array2D<float>(0, 2));
+
+ ComputeAndCompareR2<float>(&builder, Array2D<float>(0, 2), {}, error_spec_);
+}
+
+TEST_F(ConstantsTest, Small_2x2) {
+ std::unique_ptr<Array2D<float>> constant =
+ MakeLinspaceArray2D(100.0, 200.0, 2, 2);
+
+ ComputationBuilder builder(client_, TestName());
+ builder.ConstantR2FromArray2D<float>(*constant);
+
+ ComputeAndCompareR2<float>(&builder, *constant, {}, error_spec_);
+}
+
+TEST_F(ConstantsTest, Empty_3x0x2) {
+ ComputationBuilder builder(client_, TestName());
+ auto constant = builder.ConstantLiteral(
+ *LiteralUtil::CreateR3FromArray3D<float>(Array3D<float>(3, 0, 2)));
+
+ ComputeAndCompareR3<float>(&builder, Array3D<float>(3, 0, 2), {});
+}
+
+TEST_F(ConstantsTest, Small_2x2x2) {
+ ComputationBuilder builder(client_, TestName());
+ Array3D<float> array3d({
+ // x0 x1
+ {{1.f, 2.f}, // y0
+ {3.f, 4.f}}, // y1
+
+ {{5.f, 6.f}, // y0
+ {7.f, 8.f}}, // y1
+ });
+ auto constant = builder.ConstantLiteral(
+ *LiteralUtil::CreateR3FromArray3D<float>(array3d));
+
+ ComputeAndCompareR3<float>(&builder, array3d, {});
+}
+
+TEST_F(ConstantsTest, Small_3x2x1x1) {
+ Array4D<float> input_array(3, 2, 1, 1);
+ Array2D<float> pz({
+ // z0 z1
+ {-1.0f, 4.1f}, // p0
+ {2.0f, 4.1f}, // p1
+ {5.0f, 4.4f}, // p2
+ });
+ input_array.FillWithPZ(pz);
+ Literal input_literal = *LiteralUtil::CreateR4FromArray4D(input_array);
+
+ {
+ ComputationBuilder builder(client_, TestName());
+ builder.ConstantLiteral(input_literal);
+ ComputeAndCompareR4<float>(&builder, input_array, {}, error_spec_);
+ }
+
+ {
+ ComputationBuilder builder(client_, TestName());
+ builder.ConstantR4FromArray4D<float>(input_array);
+ ComputeAndCompareR4<float>(&builder, input_array, {}, error_spec_);
+ }
+}
+
+// TODO(b/29263943): Support tuple constants.
+TEST_F(ConstantsTest, DISABLED_TupleConstant) {
+ ComputationBuilder builder(client_, TestName());
+ builder.ConstantLiteral(*LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR2<float>({{1.0}, {2.0}}).get(),
+ LiteralUtil::CreateR1<float>({2.0, 42}).get()}));
+
+ std::unique_ptr<Literal> result = ExecuteAndTransferOrDie(&builder, {});
+
+ LiteralTestUtil::ExpectR2Near<float>({{1.0}, {2.0}},
+ result->tuple_literals(0), error_spec_);
+ LiteralTestUtil::ExpectR1Near<float>({2.0, 42.0}, result->tuple_literals(1),
+ error_spec_);
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/convert_test.cc b/tensorflow/compiler/xla/tests/convert_test.cc
new file mode 100644
index 0000000000..9f8c3a9aeb
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/convert_test.cc
@@ -0,0 +1,210 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <cstdint>
+#include <limits>
+#include <memory>
+#include <vector>
+
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+class ConvertTest : public ClientLibraryTestBase {
+ public:
+ explicit ConvertTest(perftools::gputools::Platform* platform = nullptr)
+ : ClientLibraryTestBase(platform,
+ /*disabled_pass_names=*/{"algsimp", "inline"}) {}
+};
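+
+// "algsimp" and "inline" are disabled above, presumably so that the Convert
+// operations under test are not simplified or constant-folded away before
+// they are executed.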
+
+TEST_F(ConvertTest, ConvertR1S32ToR1S32) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<int32>({42, 64});
+ builder.ConvertElementType(a, S32);
+
+ std::vector<int32> expected = {42, 64};
+ ComputeAndCompareR1<int32>(&builder, expected, {});
+}
+
+TEST_F(ConvertTest, ConvertR1F32ToR1F32) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<float>({42.0f, 64.0f});
+ builder.ConvertElementType(a, F32);
+
+ std::vector<float> expected = {42.0f, 64.0f};
+ ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+TEST_F(ConvertTest, ConvertR1S32ToR1F32) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<int32>({42, 64});
+ builder.ConvertElementType(a, F32);
+
+ std::vector<float> expected = {42.0f, 64.0f};
+ ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(ConvertTest, ConvertR1S0S32ToR1S0F32) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<int32>({});
+ builder.ConvertElementType(a, F32);
+
+ std::vector<float> expected = {};
+ ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+TEST_F(ConvertTest, ConvertR1F32ToR1S32) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<float>({42.6, 64.4});
+ builder.ConvertElementType(a, S32);
+
+ std::vector<int32> expected = {42, 64};
+ ComputeAndCompareR1<int32>(&builder, expected, {});
+}
+
+XLA_TEST_F(ConvertTest, ConvertR1S64ToR1F32) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<int64>({32, 64});
+ builder.ConvertElementType(a, F32);
+
+ std::vector<float> expected = {32.0, 64.0};
+ ComputeAndCompareR1<float>(&builder, expected, {});
+}
+
+XLA_TEST_F(ConvertTest, ConvertR1U8ToR1F32) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<uint8_t>({32, 64});
+ builder.ConvertElementType(a, F32);
+
+ std::vector<float> expected = {32.0, 64.0};
+ ComputeAndCompareR1<float>(&builder, expected, {});
+}
+
+XLA_TEST_F(ConvertTest, ConvertR1U8ToR1S32) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<uint8_t>({32, 64});
+ builder.ConvertElementType(a, S32);
+
+ std::vector<int32_t> expected = {32, 64};
+ ComputeAndCompareR1<int32_t>(&builder, expected, {});
+}
+
+XLA_TEST_F(ConvertTest, ConvertR1U8ToR1U32) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<uint8_t>({32, 64});
+ builder.ConvertElementType(a, U32);
+
+ std::vector<uint32_t> expected = {32, 64};
+ ComputeAndCompareR1<uint32_t>(&builder, expected, {});
+}
+
+XLA_TEST_F(ConvertTest, ConvertR1F32ToR1F64) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<float>({32.0f, 64.0f});
+ builder.ConvertElementType(a, F64);
+
+ std::vector<double> expected = {32.0, 64.0};
+ ComputeAndCompareR1<double>(&builder, expected, {});
+}
+
+XLA_TEST_F(ConvertTest, ConvertR1F64ToR1F32) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<double>({32.0, 64.0});
+ builder.ConvertElementType(a, F32);
+
+ std::vector<float> expected = {32.0f, 64.0f};
+ ComputeAndCompareR1<float>(&builder, expected, {});
+}
+
+TEST_F(ConvertTest, ConvertS32Extremes) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<int32>(
+ {std::numeric_limits<int32>::min(), std::numeric_limits<int32>::max()});
+ builder.ConvertElementType(a, F32);
+
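+ // int32 max (2^31 - 1) is not exactly representable in float; casting the
+ // expected values the same way keeps the expectation consistent with the
+ // conversion's rounding.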
+ std::vector<float> expected = {
+ static_cast<float>(std::numeric_limits<int32>::min()),
+ static_cast<float>(std::numeric_limits<int32>::max())};
+ ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+TEST_F(ConvertTest, ConvertMapToS32) {
+ ComputationBuilder builder(client_, TestName());
+ auto b = builder.CreateSubBuilder("convert");
+ auto param = b->Parameter(0, ShapeUtil::MakeShape(F32, {}), "in");
+ b->ConvertElementType(param, S32);
+ auto a = builder.ConstantR1<float>({42.0f, 64.0f});
+ builder.Map({a}, b->BuildAndNoteError());
+
+ std::vector<int32> expected = {42, 64};
+ ComputeAndCompareR1<int32>(&builder, expected, {});
+}
+
+TEST_F(ConvertTest, ConvertMapToF32) {
+ ComputationBuilder builder(client_, TestName());
+ auto b = builder.CreateSubBuilder("convert");
+ auto param = b->Parameter(0, ShapeUtil::MakeShape(S32, {}), "in");
+ b->ConvertElementType(param, F32);
+ auto a = builder.ConstantR1<int32>({42, 64});
+ builder.Map({a}, b->BuildAndNoteError());
+
+ std::vector<float> expected = {42.0f, 64.0f};
+ ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+// Regression test for b/31758660. When ReshapeMover transforms
+// input -> reshape -> convert
+// to
+// input -> convert -> reshape
+// the new convert should have the same element type as the old convert.
+TEST_F(ConvertTest, ConvertReshape) {
+ ComputationBuilder builder(client_, TestName());
+ auto input = builder.ConstantR1<int32>({42});
+ auto reshape = builder.Reshape(input, /*dimensions=*/{0}, /*new_sizes=*/{});
+ builder.ConvertElementType(reshape, F32);
+
+ ComputeAndCompareR0<float>(&builder, 42.0f, {}, ErrorSpec(0.0001));
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc b/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc
new file mode 100644
index 0000000000..9f38dc4b36
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc
@@ -0,0 +1,117 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <algorithm>
+#include <array>
+#include <memory>
+
+#include "tensorflow/compiler/xla/array4d.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/client/padding.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/reference_util.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+class ConvolutionDimensionNumbersTest : public ClientLibraryTestBase {};
+
+// Tests the convolution operation with invalid input dimension numbers.
+TEST_F(ConvolutionDimensionNumbersTest, InvalidInputDimensionNumbers) {
+ auto dimension_numbers_status =
+ ComputationBuilder::CreateConvDimensionNumbers(0, 2, 2, 3, 0, 1, 2, 3);
+ ASSERT_FALSE(dimension_numbers_status.ok());
+ ASSERT_MATCH(dimension_numbers_status.status().error_message(),
+ testing::ContainsRegex("input are not unique"));
+}
+
+// Tests the convolution operation with invalid weight dimension numbers.
+TEST_F(ConvolutionDimensionNumbersTest, InvalidWeightDimensionNumbers) {
+ auto dimension_numbers_status =
+ ComputationBuilder::CreateConvDimensionNumbers(0, 1, 2, 3, 2, 3, 2, 3);
+ ASSERT_FALSE(dimension_numbers_status.ok());
+ ASSERT_MATCH(dimension_numbers_status.status().error_message(),
+ testing::ContainsRegex("weight are not unique"));
+}
+
+XLA_TEST_F(ConvolutionDimensionNumbersTest,
+ TwoConvsWithDifferentDimensionNumbers) {
+ auto input_array = MakeUnique<Array4D<float>>(2, 3, 5, 5);
+ input_array->FillWithMultiples(0.1);
+ auto weight_array = MakeUnique<Array4D<float>>(4, 3, 1, 1);
+ weight_array->FillWithMultiples(0.2);
+ auto weight_data =
+ client_
+ ->TransferToServer(*LiteralUtil::CreateR4FromArray4D(*weight_array))
+ .ConsumeValueOrDie();
+
+ ComputationBuilder builder(client_, TestName());
+ auto input = builder.ConstantR4FromArray4D<float>(*input_array);
+ auto weight =
+ builder.Parameter(0, ShapeUtil::MakeShape(F32, {4, 3, 1, 1}), "weight");
+ auto conv1 = builder.Conv(input, weight, {1, 1}, Padding::kValid);
+
+ ConvolutionDimensionNumbers dim_nums =
+ ComputationBuilder::CreateDefaultConvDimensionNumbers();
+ // Swap batch_dimension and feature_dimension.
+ int64 tmp = dim_nums.batch_dimension();
+ dim_nums.set_batch_dimension(dim_nums.feature_dimension());
+ dim_nums.set_feature_dimension(tmp);
+ // Swap kernel_input_feature_dimension and kernel_output_feature_dimension.
+ tmp = dim_nums.kernel_input_feature_dimension();
+ dim_nums.set_kernel_input_feature_dimension(
+ dim_nums.kernel_output_feature_dimension());
+ dim_nums.set_kernel_output_feature_dimension(tmp);
+ builder.ConvWithGeneralDimensions(input, conv1, {1, 1}, Padding::kValid,
+ dim_nums);
+
+ auto expected_conv1 = ReferenceUtil::ConvArray4D(*input_array, *weight_array,
+ {1, 1}, Padding::kValid);
+ auto expected_conv2 = ReferenceUtil::ConvArray4DGeneralDimensions(
+ *input_array, *expected_conv1, {1, 1}, Padding::kValid, dim_nums);
+
+ ComputeAndCompareR4<float>(&builder, *expected_conv2, {weight_data.get()},
+ ErrorSpec(0.001, 0.01));
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc
new file mode 100644
index 0000000000..ffbda89b94
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/convolution_test.cc
@@ -0,0 +1,361 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Tests of convolution with trivial kernels and no special variations (like
+// strides and padding).
+
+#include <memory>
+#include <numeric>
+
+#include "tensorflow/compiler/xla/array2d.h"
+#include "tensorflow/compiler/xla/array4d.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/global_data.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/client/padding.h"
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/reference_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+class ConvolutionTest : public ClientLibraryTestBase {
+ protected:
+#if XLA_TEST_BACKEND_GPU
+ // XLA:GPU sometimes uses FFT convolution, which isn't as precise as
+ // spatial convolution, so relax the absolute error threshold.
+ ErrorSpec error_spec_ = ErrorSpec(1e-3);
+#else
+ ErrorSpec error_spec_ = ErrorSpec(1e-4);
+#endif
+};
+
+XLA_TEST_F(ConvolutionTest, ForwardPassConvolution_3x3x256_256_OutputZ_Iota) {
+ const int kInputActivationSizeY = 3;
+ const int kInputActivationSizeX = 3;
+ const int kInputActivationSizeZ = 256;
+ const int kKernelSizeX = 2;
+ const int kKernelSizeY = 2;
+ const int kOutputActivationSizeZ = 256;
+ const int kMiniBatchSize = 4;
+ auto alhs =
+ MakeUnique<Array4D<float>>(kMiniBatchSize, kInputActivationSizeZ,
+ kInputActivationSizeY, kInputActivationSizeX);
+ alhs->FillWithMultiples(1.0f);
+ ASSERT_EQ(3, alhs->width());
+ ASSERT_EQ(3, alhs->height());
+
+ auto arhs =
+ MakeUnique<Array4D<float>>(kOutputActivationSizeZ, kInputActivationSizeZ,
+ kKernelSizeY, kKernelSizeX);
+ Array2D<float> rhs_raster({
+ {1.0f, 0.0f}, // row 0
+ {0.0f, 0.0f}, // row 1
+ });
+ arhs->FillWithYX(rhs_raster);
+ ASSERT_EQ(2, arhs->width());
+ ASSERT_EQ(2, arhs->height());
+
+ ComputationBuilder builder(client_, TestName());
+ auto lhs = builder.ConstantR4FromArray4D<float>(*alhs);
+ auto rhs = builder.ConstantR4FromArray4D<float>(*arhs);
+ builder.Conv(lhs, rhs, {1, 1}, Padding::kValid);
+
+ std::unique_ptr<Array4D<float>> aexpected =
+ ReferenceUtil::ConvArray4D(*alhs, *arhs, {1, 1}, Padding::kValid);
+
+ ComputeAndCompareR4<float>(&builder, *aexpected, {}, error_spec_);
+}
+
+TEST_F(ConvolutionTest, Convolve_1x1x1x2_1x1x1x2_Valid) {
+ ComputationBuilder builder(client_, TestName());
+ {
+ Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2});
+ Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2});
+ auto input = builder.Parameter(0, input_shape, "input");
+ auto filter = builder.Parameter(1, filter_shape, "filter");
+ builder.Conv(input, filter, {1, 1}, Padding::kValid);
+ }
+
+ Array4D<float> input(1, 1, 1, 2);
+ input.FillWithYX(Array2D<float>({
+ {1, 2},
+ }));
+ Array4D<float> filter(1, 1, 1, 2);
+ filter.FillWithYX(Array2D<float>({
+ {5, 6},
+ }));
+
+ std::unique_ptr<Array4D<float>> aexpected =
+ ReferenceUtil::ConvArray4D(input, filter, {1, 1}, Padding::kValid);
+
+ auto input_literal =
+ client_->TransferToServer(*LiteralUtil::CreateR4FromArray4D(input))
+ .ConsumeValueOrDie();
+ auto filter_literal =
+ client_->TransferToServer(*LiteralUtil::CreateR4FromArray4D(filter))
+ .ConsumeValueOrDie();
+
+ ComputeAndCompareR4<float>(&builder, *aexpected,
+ {input_literal.get(), filter_literal.get()},
+ error_spec_);
+}
+
+// Tests valid padding for 2D convolution in raster space.
+TEST_F(ConvolutionTest, Convolve_1x1x4x4_1x1x2x2_Valid) {
+ ComputationBuilder builder(client_, TestName());
+ {
+ Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4});
+ Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 2, 2});
+ auto input = builder.Parameter(0, input_shape, "input");
+ auto filter = builder.Parameter(1, filter_shape, "filter");
+ builder.Conv(input, filter, {1, 1}, Padding::kValid);
+ }
+
+ Array4D<float> input(1, 1, 4, 4);
+ // clang-format off
+ input.FillWithYX(Array2D<float>({
+ {1, 2, 3, 4 },
+ {5, 6, 7, 8 },
+ {9, 10, 11, 12},
+ {13, 14, 15, 16},
+ }));
+ // clang-format on
+ Array4D<float> filter(1, 1, 2, 2);
+ // clang-format off
+ filter.FillWithYX(Array2D<float>({
+ {5, 6},
+ {7, 8},
+ }));
+ // clang-format on
+
+ std::unique_ptr<Array4D<float>> aexpected =
+ ReferenceUtil::ConvArray4D(input, filter, {1, 1}, Padding::kValid);
+
+ auto input_literal =
+ client_->TransferToServer(*LiteralUtil::CreateR4FromArray4D(input))
+ .ConsumeValueOrDie();
+ auto filter_literal =
+ client_->TransferToServer(*LiteralUtil::CreateR4FromArray4D(filter))
+ .ConsumeValueOrDie();
+
+ ComputeAndCompareR4<float>(&builder, *aexpected,
+ {input_literal.get(), filter_literal.get()},
+ error_spec_);
+}
+
+// Tests same padding for 2D convolution in raster space.
+TEST_F(ConvolutionTest, Convolve_1x1x4x4_1x1x2x2_Same) {
+ ComputationBuilder builder(client_, TestName());
+ {
+ Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4});
+ Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 2, 2});
+ auto input = builder.Parameter(0, input_shape, "input");
+ auto filter = builder.Parameter(1, filter_shape, "filter");
+ builder.Conv(input, filter, {1, 1}, Padding::kSame);
+ }
+
+ Array4D<float> input(1, 1, 4, 4);
+ // clang-format off
+ input.FillWithYX(Array2D<float>({
+ {1, 2, 3, 4 },
+ {5, 6, 7, 8 },
+ {9, 10, 11, 12},
+ {13, 14, 15, 16},
+ }));
+ // clang-format on
+ Array4D<float> filter(1, 1, 2, 2);
+ // clang-format off
+ filter.FillWithYX(Array2D<float>({
+ {5, 6},
+ {7, 8},
+ }));
+ // clang-format on
+
+ std::unique_ptr<Array4D<float>> aexpected =
+ ReferenceUtil::ConvArray4D(input, filter, {1, 1}, Padding::kSame);
+
+ auto input_literal =
+ client_->TransferToServer(*LiteralUtil::CreateR4FromArray4D(input))
+ .ConsumeValueOrDie();
+ auto filter_literal =
+ client_->TransferToServer(*LiteralUtil::CreateR4FromArray4D(filter))
+ .ConsumeValueOrDie();
+
+ ComputeAndCompareR4<float>(&builder, *aexpected,
+ {input_literal.get(), filter_literal.get()},
+ error_spec_);
+}
+
+// Tests same padding for 2D convolution in raster space with an odd sized
+// kernel.
+TEST_F(ConvolutionTest, Convolve_1x1x4x4_1x1x3x3_Same) {
+ ComputationBuilder builder(client_, TestName());
+ {
+ Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4});
+ Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 3, 3});
+ auto input = builder.Parameter(0, input_shape, "input");
+ auto filter = builder.Parameter(1, filter_shape, "filter");
+ builder.Conv(input, filter, {1, 1}, Padding::kSame);
+ }
+
+ Array4D<float> input(1, 1, 4, 4);
+ // clang-format off
+ input.FillWithYX(Array2D<float>({
+ {1, 2, 3, 4 },
+ {5, 6, 7, 8 },
+ {9, 10, 11, 12},
+ {13, 14, 15, 16},
+ }));
+ // clang-format on
+ Array4D<float> filter(1, 1, 3, 3);
+ // clang-format off
+ filter.FillWithYX(Array2D<float>({
+ { 5, 6, 7},
+ { 8, 9, 10},
+ {11, 12, 13},
+ }));
+ // clang-format on
+
+ std::unique_ptr<Array4D<float>> aexpected =
+ ReferenceUtil::ConvArray4D(input, filter, {1, 1}, Padding::kSame);
+
+ auto input_literal =
+ client_->TransferToServer(*LiteralUtil::CreateR4FromArray4D(input))
+ .ConsumeValueOrDie();
+ auto filter_literal =
+ client_->TransferToServer(*LiteralUtil::CreateR4FromArray4D(filter))
+ .ConsumeValueOrDie();
+
+ ComputeAndCompareR4<float>(&builder, *aexpected,
+ {input_literal.get(), filter_literal.get()},
+ error_spec_);
+}
+
+// TODO(b/32873825): implement 1D convolution on GPU.
+XLA_TEST_F(ConvolutionTest, DISABLED_ON_GPU(Convolve1D_1x2x5_1x2x2_Valid)) {
+ ComputationBuilder builder(client_, TestName());
+ {
+ Shape input_shape = ShapeUtil::MakeShape(F32, {1, 2, 5});
+ Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 2, 2});
+ auto input = builder.Parameter(0, input_shape, "input");
+ auto filter = builder.Parameter(1, filter_shape, "filter");
+ builder.Conv(input, filter, {1}, Padding::kValid);
+ }
+
+ Array3D<float> input({{{1, 2, 3, 4, 5}, {6, 7, 8, 9, 10}}});
+ Array3D<float> filter({{{10, 20}, {30, 40}}});
+
+ Array3D<float> expected({{{510, 610, 710, 810}}});
+
+ auto input_literal =
+ client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input))
+ .ConsumeValueOrDie();
+ auto filter_literal =
+ client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter))
+ .ConsumeValueOrDie();
+
+ ComputeAndCompareR3<float>(&builder, expected,
+ {input_literal.get(), filter_literal.get()},
+ error_spec_);
+}
+
+// TODO(b/32873825): implement 3D convolution on GPU.
+XLA_TEST_F(ConvolutionTest,
+ DISABLED_ON_GPU(Convolve3D_1x4x2x3x3_2x2x2x3x3_Valid)) {
+ ComputationBuilder builder(client_, TestName());
+ std::vector<int64> input_dims = {1, 4, 2, 3, 3};
+ std::vector<int64> filter_dims = {2, 2, 2, 3, 3};
+ Shape input_shape = ShapeUtil::MakeShape(F32, input_dims);
+ Shape filter_shape = ShapeUtil::MakeShape(F32, filter_dims);
+ {
+ auto input = builder.Parameter(0, input_shape, "input");
+ auto filter = builder.Parameter(1, filter_shape, "filter");
+
+ // TensorFlow dimension numbers for 3D convolution.
+ ConvolutionDimensionNumbers dnums;
+ dnums.set_batch_dimension(0);
+ dnums.add_spatial_dimensions(1);
+ dnums.add_spatial_dimensions(2);
+ dnums.add_spatial_dimensions(3);
+ dnums.set_feature_dimension(4);
+ dnums.add_kernel_spatial_dimensions(0);
+ dnums.add_kernel_spatial_dimensions(1);
+ dnums.add_kernel_spatial_dimensions(2);
+ dnums.set_kernel_input_feature_dimension(3);
+ dnums.set_kernel_output_feature_dimension(4);
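+ // I.e. the input is laid out N, D, H, W, C and the kernel D, H, W, I, O,
+ // matching TensorFlow's NDHWC / DHWIO convention for 3D convolution.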
+
+ builder.ConvWithGeneralDimensions(input, filter, {1, 1, 1}, Padding::kValid,
+ dnums);
+ }
+
+ std::vector<float> input_elems(ShapeUtil::ElementsIn(input_shape));
+ std::iota(input_elems.begin(), input_elems.end(), 1.0f);
+ auto input_r1 = LiteralUtil::CreateR1<float>(input_elems);
+ auto input_r5 =
+ LiteralUtil::Reshape(*input_r1, input_dims).ConsumeValueOrDie();
+
+ std::vector<float> filter_elems(ShapeUtil::ElementsIn(filter_shape));
+ std::iota(filter_elems.begin(), filter_elems.end(), 1.0f);
+ auto filter_r1 = LiteralUtil::CreateR1<float>(filter_elems);
+ auto filter_r5 =
+ LiteralUtil::Reshape(*filter_r1, filter_dims).ConsumeValueOrDie();
+
+ auto expected_r1 = LiteralUtil::CreateR1<float>(
+ {19554, 19962, 20370, 22110, 22590, 23070, 34890, 35730, 36570, 37446,
+ 38358, 39270, 50226, 51498, 52770, 52782, 54126, 55470});
+ auto expected_r5 =
+ LiteralUtil::Reshape(*expected_r1, {1, 3, 1, 2, 3}).ConsumeValueOrDie();
+
+ auto input_literal = client_->TransferToServer(*input_r5).ConsumeValueOrDie();
+ auto filter_literal =
+ client_->TransferToServer(*filter_r5).ConsumeValueOrDie();
+
+ ComputeAndCompareLiteral(&builder, *expected_r5,
+ {input_literal.get(), filter_literal.get()},
+ error_spec_);
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/convolution_variants_test.cc b/tensorflow/compiler/xla/tests/convolution_variants_test.cc
new file mode 100644
index 0000000000..b599f9b95b
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/convolution_variants_test.cc
@@ -0,0 +1,1294 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Tests of convolution variants -- kernel sizes, padding, and strides --
+// in small-sized data.
+
+#include <algorithm>
+#include <initializer_list>
+#include <memory>
+#include <numeric>
+#include <random>
+#include <vector>
+
+#include "tensorflow/compiler/xla/array4d.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/client/padding.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/reference_util.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+class ConvolutionVariantsTest : public ClientLibraryTestBase {
+ protected:
+#if XLA_TEST_BACKEND_GPU
+ // XLA:GPU sometimes uses FFT convolution, which isn't as precise as
+ // spatial convolution, so relax the absolute error threshold.
+ ErrorSpec error_spec_ = ErrorSpec(1e-1, 1e-5);
+#else
+ ErrorSpec error_spec_ = ErrorSpec(1e-4, 1e-2);
+#endif
+};
+
+TEST_F(ConvolutionVariantsTest, Minimal) {
+ ComputationBuilder builder(client_, TestName());
+
+ const Array4D<float> input_array(1, 1, 1, 1, {2});
+ auto input = builder.ConstantR4FromArray4D<float>(input_array);
+
+ const Array4D<float> filter_array(1, 1, 1, 1, {3});
+ auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+
+ builder.Conv(input, filter, {1, 1}, Padding::kValid);
+
+ const Array4D<float> expected(1, 1, 1, 1, {6});
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+TEST_F(ConvolutionVariantsTest, MinimalWithBatch) {
+ ComputationBuilder builder(client_, TestName());
+
+ const Array4D<float> input_array(5, 1, 1, 1, {1, 2, 3, 4, 5});
+ auto input = builder.ConstantR4FromArray4D<float>(input_array);
+
+ const Array4D<float> filter_array(1, 1, 1, 1, {2});
+ auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+
+ builder.Conv(input, filter, {1, 1}, Padding::kValid);
+
+ const Array4D<float> expected(5, 1, 1, 1, {2, 4, 6, 8, 10});
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+TEST_F(ConvolutionVariantsTest, Flat1x1) {
+ ComputationBuilder builder(client_, TestName());
+
+ Array4D<float> input_array(2, 1, 3, 4);
+ input_array.FillWithMultiples(1);
+ auto input = builder.ConstantR4FromArray4D<float>(input_array);
+
+ const Array4D<float> filter_array(1, 1, 1, 1, {2.3});
+ auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+
+ builder.Conv(input, filter, {1, 1}, Padding::kValid);
+
+ Array4D<float> expected(2, 1, 3, 4);
+ expected.FillWithMultiples(2.3);
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+TEST_F(ConvolutionVariantsTest, Deep1x1) {
+ ComputationBuilder builder(client_, TestName());
+
+ Array4D<float> input_array(1, 2, 1, 1, {10, 1});
+ auto input = builder.ConstantR4FromArray4D<float>(input_array);
+
+ const Array4D<float> filter_array(3, 2, 1, 1, {1, 2, 3, 4, 5, 6});
+ auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+
+ builder.Conv(input, filter, {1, 1}, Padding::kValid);
+
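+ // Each output feature dots the filter's two input channels with the input
+ // {10, 1}: 10*1 + 1*2 = 12, 10*3 + 1*4 = 34, 10*5 + 1*6 = 56.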
+ Array4D<float> expected(1, 3, 1, 1, {12, 34, 56});
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+TEST_F(ConvolutionVariantsTest, Filter1x2in1x2) {
+ ComputationBuilder builder(client_, TestName());
+
+ Array4D<float> input_array(1, 1, 1, 2, {1, 2});
+ auto input = builder.ConstantR4FromArray4D<float>(input_array);
+
+ const Array4D<float> filter_array(1, 1, 1, 2, {10, 1});
+ auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+
+ builder.Conv(input, filter, {1, 1}, Padding::kValid);
+
+ Array4D<float> expected(1, 1, 1, 1, {12});
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+TEST_F(ConvolutionVariantsTest, Filter1x2in1x3) {
+ ComputationBuilder builder(client_, TestName());
+
+ Array4D<float> input_array(1, 1, 1, 3, {1, 2, 3});
+ auto input = builder.ConstantR4FromArray4D<float>(input_array);
+
+ const Array4D<float> filter_array(1, 1, 1, 2, {10, 1});
+ auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+
+ builder.Conv(input, filter, {1, 1}, Padding::kValid);
+
+ Array4D<float> expected(1, 1, 1, 2, {12, 23});
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+TEST_F(ConvolutionVariantsTest, Filter1x2in2x2) {
+ ComputationBuilder builder(client_, TestName());
+
+ Array4D<float> input_array(1, 1, 2, 2, {1, 2, 3, 4});
+ auto input = builder.ConstantR4FromArray4D<float>(input_array);
+
+ const Array4D<float> filter_array(1, 1, 1, 2, {10, 1});
+ auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+
+ builder.Conv(input, filter, {1, 1}, Padding::kValid);
+
+ Array4D<float> expected(1, 1, 2, 1, {12, 34});
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+TEST_F(ConvolutionVariantsTest, Filter2x1in2x2) {
+ ComputationBuilder builder(client_, TestName());
+
+ Array4D<float> input_array(1, 1, 2, 2, {1, 2, 3, 4});
+ auto input = builder.ConstantR4FromArray4D<float>(input_array);
+
+ const Array4D<float> filter_array(1, 1, 2, 1, {10, 1});
+ auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+
+ builder.Conv(input, filter, {1, 1}, Padding::kValid);
+
+ Array4D<float> expected(1, 1, 1, 2, {13, 24});
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+TEST_F(ConvolutionVariantsTest, Filter2x2in2x2) {
+ ComputationBuilder builder(client_, TestName());
+
+ Array4D<float> input_array(1, 1, 2, 2, {1, 2, 3, 4});
+ auto input = builder.ConstantR4FromArray4D<float>(input_array);
+
+ const Array4D<float> filter_array(1, 1, 2, 2, {1000, 100, 10, 1});
+ auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+
+ builder.Conv(input, filter, {1, 1}, Padding::kValid);
+
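+ // 1*1000 + 2*100 + 3*10 + 4*1 = 1234: each input element is weighted by a
+ // distinct power of ten, so the digits show which weight hit which element.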
+ Array4D<float> expected(1, 1, 1, 1, {1234});
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+TEST_F(ConvolutionVariantsTest, Filter1x2in2x3WithDepthAndBatch) {
+ ComputationBuilder builder(client_, TestName());
+
+ Array4D<float> input_array(
+ 2, 2, 2, 3, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 0, // plane 0
+ 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 0, 0}); // plane 1
+ auto input = builder.ConstantR4FromArray4D<float>(input_array);
+
+ const Array4D<float> filter_array(
+ 2, 2, 1, 2, {1000, 100, 10, 1, 0.1, 0.01, 0.001, 0.0001});
+ auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+
+ builder.Conv(input, filter, {1, 1}, Padding::kValid);
+
+ Array4D<float> expected(
+ 2, 2, 2, 2,
+ {167, 1278, 3490, 4500, 0.0167, 0.1278, 0.3490, 0.4500, // plane 0
+ 334, 2556, 6980, 9000, 0.0334, 0.2556, 0.6980, 0.9000}); // plane 1
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+TEST_F(ConvolutionVariantsTest, Filter1x1stride1x2in1x4) {
+ ComputationBuilder builder(client_, TestName());
+
+ Array4D<float> input_array(1, 1, 1, 4, {1, 2, 3, 4});
+ auto input = builder.ConstantR4FromArray4D<float>(input_array);
+
+ const Array4D<float> filter_array(1, 1, 1, 1, {10});
+ auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+
+ builder.Conv(input, filter, {1, 2}, Padding::kValid);
+
+ Array4D<float> expected(1, 1, 1, 2, {10, 30});
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+TEST_F(ConvolutionVariantsTest, Filter1x1stride1x2in1x5) {
+ ComputationBuilder builder(client_, TestName());
+
+ Array4D<float> input_array(1, 1, 1, 5, {1, 2, 3, 4, 5});
+ auto input = builder.ConstantR4FromArray4D<float>(input_array);
+
+ const Array4D<float> filter_array(1, 1, 1, 1, {10});
+ auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+
+ builder.Conv(input, filter, {1, 2}, Padding::kValid);
+
+ Array4D<float> expected(1, 1, 1, 3, {10, 30, 50});
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+TEST_F(ConvolutionVariantsTest, Filter1x3stride1x2in1x4) {
+ ComputationBuilder builder(client_, TestName());
+
+ Array4D<float> input_array(1, 1, 1, 4, {1, 2, 3, 4});
+ auto input = builder.ConstantR4FromArray4D<float>(input_array);
+
+ const Array4D<float> filter_array(1, 1, 1, 3, {100, 10, 1});
+ auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+
+ builder.Conv(input, filter, {1, 2}, Padding::kValid);
+
+ Array4D<float> expected(1, 1, 1, 1, {123});
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+TEST_F(ConvolutionVariantsTest, Filter1x3stride1x2in1x5) {
+ ComputationBuilder builder(client_, TestName());
+
+ Array4D<float> input_array(1, 1, 1, 5, {1, 2, 3, 4, 5});
+ auto input = builder.ConstantR4FromArray4D<float>(input_array);
+
+ const Array4D<float> filter_array(1, 1, 1, 3, {100, 10, 1});
+ auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+
+ builder.Conv(input, filter, {1, 2}, Padding::kValid);
+
+ Array4D<float> expected(1, 1, 1, 2, {123, 345});
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+TEST_F(ConvolutionVariantsTest, Filter1x1stride2x2in3x3) {
+ ComputationBuilder builder(client_, TestName());
+
+ Array4D<float> input_array(1, 1, 3, 3, {1, 2, 3, 4, 5, 6, 7, 8, 9});
+ auto input = builder.ConstantR4FromArray4D<float>(input_array);
+
+ const Array4D<float> filter_array(1, 1, 1, 1, {10});
+ auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+
+ builder.Conv(input, filter, {2, 2}, Padding::kValid);
+
+ Array4D<float> expected(1, 1, 2, 2, {10, 30, 70, 90});
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+TEST_F(ConvolutionVariantsTest, Filter3x1in1x1Padded) {
+ ComputationBuilder builder(client_, TestName());
+
+ Array4D<float> input_array(1, 1, 1, 1, {1});
+ auto input = builder.ConstantR4FromArray4D<float>(input_array);
+
+ const Array4D<float> filter_array(1, 1, 1, 3, {10, 20, 30});
+ auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+
+ builder.Conv(input, filter, {1, 1}, Padding::kSame);
+
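+  // kSame padding extends the input to [0, 1, 0], so only the middle filter
+  // tap (20) contributes to the output.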
+ Array4D<float> expected(1, 1, 1, 1, {20});
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+TEST_F(ConvolutionVariantsTest, Filter5x1in3x1Padded) {
+ ComputationBuilder builder(client_, TestName());
+
+ Array4D<float> input_array(1, 1, 1, 3, {1, 2, 3});
+ auto input = builder.ConstantR4FromArray4D<float>(input_array);
+
+ const Array4D<float> filter_array(1, 1, 1, 5, {10000, 1000, 100, 10, 1});
+ auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+
+ builder.Conv(input, filter, {1, 1}, Padding::kSame);
+
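+  // kSame padding extends the input to [0, 0, 1, 2, 3, 0, 0]; sliding the
+  // filter over it yields 123, 1230, and 12300.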
+ Array4D<float> expected(1, 1, 1, 3, {123, 1230, 12300});
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+TEST_F(ConvolutionVariantsTest, Filter3x3in2x2Padded) {
+ ComputationBuilder builder(client_, TestName());
+
+ Array4D<float> input_array(1, 1, 2, 2, {1, 2, 3, 4});
+ auto input = builder.ConstantR4FromArray4D<float>(input_array);
+
+ const Array4D<float> filter_array(1, 1, 3, 3, {10000, 0, 1000, // row 0
+ 0, 100, 0, // row 1
+ 10, 0, 1}); // row 2
+ auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+
+ builder.Conv(input, filter, {1, 1}, Padding::kSame);
+
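+  // With kSame padding, expected(0, 0) sees only input 1 under the center
+  // tap (weight 100) and input 4 under the bottom-right tap (weight 1): 104.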
+ Array4D<float> expected(1, 1, 2, 2, {104, 230, 2300, 10400});
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+TEST_F(ConvolutionVariantsTest, Filter1x1in2x1WithPaddingAndDepth) {
+ ComputationBuilder builder(client_, TestName());
+
+ Array4D<float> input_array(1, 2, 1, 2, {1, 2, 3, 4});
+ auto input = builder.ConstantR4FromArray4D<float>(input_array);
+
+ const Array4D<float> filter_array(1, 2, 1, 1, {10, 1});
+ auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+
+ builder.Conv(input, filter, {1, 1}, Padding::kSame);
+
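+  // The output mixes the two input features: 10*{1, 2} + 1*{3, 4} = {13, 24}.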
+ Array4D<float> expected(1, 1, 1, 2, {13, 24});
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+TEST_F(ConvolutionVariantsTest, Filter2x2Stride1x1Input3x3) {
+ ComputationBuilder builder(client_, TestName());
+
+ Array4D<float> input_array(1, 1, 3, 3, {1, 2, 3, 4, 5, 6, 7, 8, 9});
+ auto input = builder.ConstantR4FromArray4D<float>(input_array);
+
+ const Array4D<float> filter_array(1, 1, 2, 2, {7, 13, 17, 23});
+ auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+
+ builder.Conv(input, filter, {1, 1}, Padding::kValid);
+
+ Array4D<float> expected(1, 1, 2, 2, {216, 276, 396, 456});
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+TEST_F(ConvolutionVariantsTest, Filter1x2Stride1x1Input1x3) {
+ ComputationBuilder builder(client_, TestName());
+
+ Array4D<float> input_array(1, 1, 1, 3, {1, 2, 3});
+ auto input = builder.ConstantR4FromArray4D<float>(input_array);
+
+ const Array4D<float> filter_array(1, 1, 1, 2, {7, 13});
+ auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+
+ builder.Conv(input, filter, {1, 1}, Padding::kValid);
+
+ Array4D<float> expected(1, 1, 1, 2, {33, 53});
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+TEST_F(ConvolutionVariantsTest, Filter2x1x8x8Input1x1x8x8) {
+ ComputationBuilder builder(client_, TestName());
+
+ std::vector<float> input_data(64);
+ std::iota(input_data.begin(), input_data.end(), 0.0);
+ Array4D<float> input_array(1, 1, 8, 8, input_data);
+ auto input = builder.ConstantR4FromArray4D<float>(input_array);
+
+ std::vector<float> filter_data(128);
+ std::fill(filter_data.begin(), filter_data.begin() + 64, 1.0);
+ std::fill(filter_data.begin() + 64, filter_data.begin() + 128, 2.0);
+ const Array4D<float> filter_array(2, 1, 8, 8, filter_data);
+ auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+
+ builder.Conv(input, filter, {1, 1}, Padding::kValid);
+
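+  // The input sums to 0 + 1 + ... + 63 = 2016, so the all-ones output
+  // feature is 2016 and the all-twos output feature is 4032.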
+ Array4D<float> expected(1, 2, 1, 1, {2016, 4032});
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+TEST_F(ConvolutionVariantsTest, Filter1x1x1x1Input16x1x1x1) {
+ ComputationBuilder builder(client_, TestName());
+
+ std::vector<float> input_data(16 * 1 * 1 * 1);
+ std::iota(input_data.begin(), input_data.end(), 1.0);
+ Array4D<float> input_array(16, 1, 1, 1, input_data);
+ auto input = builder.ConstantR4FromArray4D<float>(input_array);
+
+ std::vector<float> filter_data(1 * 1 * 1 * 1);
+ std::iota(filter_data.begin(), filter_data.end(), 1.0);
+ const Array4D<float> filter_array(1, 1, 1, 1, filter_data);
+ auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+
+ builder.Conv(input, filter, {1, 1}, Padding::kValid);
+
+ std::vector<float> expected_data = {1, 2, 3, 4, 5, 6, 7, 8,
+ 9, 10, 11, 12, 13, 14, 15, 16};
+ Array4D<float> expected(16, 1, 1, 1, expected_data);
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+TEST_F(ConvolutionVariantsTest, Filter1x1x2x2Input16x1x2x2) {
+ ComputationBuilder builder(client_, TestName());
+
+ constexpr int bs = 16;
+ constexpr int kx = 2;
+ constexpr int ky = 2;
+ Array4D<float> input_array(bs, 1, ky, kx);
+ for (int i0 = 0; i0 < bs; ++i0) {
+ for (int i2 = 0; i2 < ky; ++i2) {
+ for (int i3 = 0; i3 < kx; ++i3) {
+ input_array(i0, 0, i2, i3) = i0 + 1;
+ }
+ }
+ }
+ auto input = builder.ConstantR4FromArray4D<float>(input_array);
+
+ std::vector<float> filter_data(1 * 1 * ky * kx);
+ std::iota(filter_data.begin(), filter_data.end(), 1.0);
+ const Array4D<float> filter_array(1, 1, ky, kx, filter_data);
+ auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+
+ builder.Conv(input, filter, {1, 1}, Padding::kValid);
+
+ std::vector<float> expected_data(bs);
+ for (int i = 0; i < bs; ++i) {
+ expected_data[i] = 10 * (i + 1);
+ }
+ Array4D<float> expected(bs, 1, 1, 1, expected_data);
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+TEST_F(ConvolutionVariantsTest, Filter1x1x2x2Input3x1x2x2) {
+ ComputationBuilder builder(client_, TestName());
+
+ constexpr int kx = 2;
+ constexpr int ky = 2;
+ constexpr int bs = 3;
+ Array4D<float> input_array(bs, 1, ky, kx);
+ for (int i0 = 0; i0 < bs; ++i0) {
+ for (int i2 = 0; i2 < ky; ++i2) {
+ for (int i3 = 0; i3 < kx; ++i3) {
+ input_array(i0, 0, i2, i3) = i0 + i2 + i3 + 1;
+ }
+ }
+ }
+ auto input = builder.ConstantR4FromArray4D<float>(input_array);
+
+ std::vector<float> filter_data(1 * 1 * ky * kx);
+ std::iota(filter_data.begin(), filter_data.end(), 1.0);
+ const Array4D<float> filter_array(1, 1, ky, kx, filter_data);
+ auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+
+ builder.Conv(input, filter, {1, 1}, Padding::kValid);
+
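+  // For batch b, the 2x2 window holds {b+1, b+2, b+2, b+3}; its dot product
+  // with the iota filter {1, 2, 3, 4} is 23 + 10*b.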
+ std::vector<float> expected_data = {
+ 23, 33, 43,
+ };
+ Array4D<float> expected(bs, 1, 1, 1, expected_data);
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+TEST_F(ConvolutionVariantsTest, Filter1x1x8x8Input16x1x8x8) {
+ ComputationBuilder builder(client_, TestName());
+
+ Array4D<float> input_array(16, 1, 8, 8);
+ for (int i0 = 0; i0 < 16; ++i0) {
+ for (int i2 = 0; i2 < 8; ++i2) {
+ for (int i3 = 0; i3 < 8; ++i3) {
+ input_array(i0, 0, i2, i3) = i0 + i2 + i3 + 1;
+ }
+ }
+ }
+ auto input = builder.ConstantR4FromArray4D<float>(input_array);
+
+ std::vector<float> filter_data(1 * 1 * 8 * 8);
+ std::iota(filter_data.begin(), filter_data.end(), 1.0);
+ const Array4D<float> filter_array(1, 1, 8, 8, filter_data);
+ auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+
+ builder.Conv(input, filter, {1, 1}, Padding::kValid);
+
+ std::vector<float> expected_data = {
+ 19664, 21744, 23824, 25904, 27984, 30064, 32144, 34224,
+ 36304, 38384, 40464, 42544, 44624, 46704, 48784, 50864,
+ };
+ Array4D<float> expected(16, 1, 1, 1, expected_data);
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+TEST_F(ConvolutionVariantsTest, Filter2x2x8x8Input1x2x8x8) {
+ ComputationBuilder builder(client_, TestName());
+
+ std::vector<float> input_data(2 * 8 * 8);
+ std::iota(input_data.begin(), input_data.end(), 0.0);
+ Array4D<float> input_array(1, 2, 8, 8, input_data);
+ auto input = builder.ConstantR4FromArray4D<float>(input_array);
+
+ std::vector<float> filter_data(2 * 2 * 8 * 8);
+ std::fill(filter_data.begin(), filter_data.begin() + filter_data.size() / 4,
+ 1.0);
+ std::fill(filter_data.begin() + filter_data.size() / 4,
+ filter_data.begin() + filter_data.size() / 2, 2.0);
+ std::fill(filter_data.begin() + filter_data.size() / 2,
+ filter_data.begin() + 3 * filter_data.size() / 4, 3.0);
+ std::fill(filter_data.begin() + 3 * filter_data.size() / 4, filter_data.end(),
+ 4.0);
+ const Array4D<float> filter_array(2, 2, 8, 8, filter_data);
+ auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+
+ builder.Conv(input, filter, {1, 1}, Padding::kValid);
+
+ Array4D<float> expected(1, 2, 1, 1, {14240, 30496});
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+TEST_F(ConvolutionVariantsTest, Filter2x2x8x8Input2x2x8x8) {
+ ComputationBuilder builder(client_, TestName());
+
+ std::vector<float> input_data(2 * 2 * 8 * 8);
+ std::iota(input_data.begin(), input_data.end(), 0.0);
+ Array4D<float> input_array(2, 2, 8, 8, input_data);
+ auto input = builder.ConstantR4FromArray4D<float>(input_array);
+
+ std::vector<float> filter_data(2 * 2 * 8 * 8);
+ std::fill(filter_data.begin(), filter_data.begin() + filter_data.size() / 4,
+ 1.0);
+ std::fill(filter_data.begin() + filter_data.size() / 4,
+ filter_data.begin() + filter_data.size() / 2, 2.0);
+ std::fill(filter_data.begin() + filter_data.size() / 2,
+ filter_data.begin() + 3 * filter_data.size() / 4, 3.0);
+ std::fill(filter_data.begin() + 3 * filter_data.size() / 4, filter_data.end(),
+ 4.0);
+ const Array4D<float> filter_array(2, 2, 8, 8, filter_data);
+ auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+
+ builder.Conv(input, filter, {1, 1}, Padding::kValid);
+
+ Array4D<float> expected(2, 2, 1, 1, {14240, 30496, 38816, 87840});
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+TEST_F(ConvolutionVariantsTest, Filter2x2x8x8Input32x2x8x8) {
+ ComputationBuilder builder(client_, TestName());
+
+ std::vector<float> input_data(32 * 2 * 8 * 8);
+ std::iota(input_data.begin(), input_data.end(), 0.0);
+ Array4D<float> input_array(32, 2, 8, 8, input_data);
+ auto input = builder.ConstantR4FromArray4D<float>(input_array);
+
+ std::vector<float> filter_data(2 * 2 * 8 * 8);
+ std::fill(filter_data.begin(), filter_data.begin() + filter_data.size() / 4,
+ 1.0);
+ std::fill(filter_data.begin() + filter_data.size() / 4,
+ filter_data.begin() + filter_data.size() / 2, 2.0);
+ std::fill(filter_data.begin() + filter_data.size() / 2,
+ filter_data.begin() + 3 * filter_data.size() / 4, 3.0);
+ std::fill(filter_data.begin() + 3 * filter_data.size() / 4, filter_data.end(),
+ 4.0);
+ const Array4D<float> filter_array(2, 2, 8, 8, filter_data);
+ auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+
+ builder.Conv(input, filter, {1, 1}, Padding::kValid);
+
+ std::vector<float> expected_data = {
+ 14240, 30496, 38816, 87840, 63392, 145184, 87968,
+ 202528, 112544, 259872, 137120, 317216, 161696, 374560,
+ 186272, 431904, 210848, 489248, 235424, 546592, 260000,
+ 603936, 284576, 661280, 309152, 718624, 333728, 775968,
+ 358304, 833312, 382880, 890656, 407456, 948000, 432032,
+ 1005344, 456608, 1062688, 481184, 1120032, 505760, 1177376,
+ 530336, 1.23472e+06, 554912, 1292064, 579488, 1349408, 604064,
+ 1406752, 628640, 1464096, 653216, 1.52144e+06, 677792, 1578784,
+ 702368, 1636128, 726944, 1693472, 751520, 1750816, 776096,
+ 1.80816e+06,
+ };
+ Array4D<float> expected(32, 2, 1, 1, expected_data);
+  // The output elements can be larger than 1e+5, which sometimes makes the
+  // absolute error large. So, we focus on relative error for this test case.
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+TEST_F(ConvolutionVariantsTest, Filter16x16x1x1Input16x16x1x1) {
+ ComputationBuilder builder(client_, TestName());
+
+ Array4D<float> input_array(16, 16, 1, 1);
+ Array4D<float> filter_array(16, 16, 1, 1);
+ for (int i0 = 0; i0 < 16; ++i0) {
+ for (int i1 = 0; i1 < 16; ++i1) {
+ input_array(i0, i1, 0, 0) = 1000 * i0 + i1;
+ filter_array(i0, i1, 0, 0) = 1;
+ }
+ }
+
+ auto input = builder.ConstantR4FromArray4D<float>(input_array);
+ auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+ builder.Conv(input, filter, {1, 1}, Padding::kValid);
+
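+  // Summing input(i0, i1) = 1000*i0 + i1 over the 16 input features gives
+  // 16 * 1000 * i0 + (0 + 1 + ... + 15) = 16000*i0 + 120.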
+ Array4D<float> expected(16, 16, 1, 1);
+ for (int i0 = 0; i0 < 16; ++i0) {
+ for (int i1 = 0; i1 < 16; ++i1) {
+ expected(i0, i1, 0, 0) = 16000 * i0 + 120;
+ }
+ }
+
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+XLA_TEST_F(ConvolutionVariantsTest, FlatRhsDilation) {
+ ComputationBuilder builder(client_, TestName());
+
+ std::vector<float> input_data(1 * 1 * 4 * 6);
+ std::iota(input_data.begin(), input_data.end(), 0.0);
+ Array4D<float> input_array(1, 1, 4, 6, input_data);
+
+ Array4D<float> filter_array(1, 1, 2, 3, {1, 10, 100, 2, 20, 200});
+ auto input = builder.ConstantR4FromArray4D<float>(input_array);
+ auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+ builder.ConvGeneralDilated(
+ /*lhs=*/input, /*rhs=*/filter, /*window_strides=*/{}, /*padding=*/{},
+ /*lhs_dilation=*/{}, /*rhs_dilation=*/{2, 2},
+ ComputationBuilder::CreateDefaultConvDimensionNumbers());
+
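+  // The {2, 2} rhs_dilation spreads the 2x3 filter over a 3x5 window, e.g.
+  // expected(0, 0) = 0*1 + 2*10 + 4*100 + 12*2 + 14*20 + 16*200 = 3924.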
+ Array4D<float> expected(1, 1, 2, 2, {3924, 4257, 5922, 6255});
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+XLA_TEST_F(ConvolutionVariantsTest, FlatLhsDilation1D) {
+ ComputationBuilder builder(client_, TestName());
+
+ std::vector<float> input_data(1 * 1 * 1 * 5);
+ std::iota(input_data.begin(), input_data.end(), 1.0);
+ Array4D<float> input_array(1, 1, 1, 5, input_data);
+
+ Array4D<float> filter_array(1, 1, 1, 2, {10, 1});
+ auto input = builder.ConstantR4FromArray4D<float>(input_array);
+ auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+ builder.ConvGeneralDilated(
+ /*lhs=*/input, /*rhs=*/filter, /*window_strides=*/{}, /*padding=*/{},
+ /*lhs_dilation=*/{1, 2}, /*rhs_dilation=*/{},
+ ComputationBuilder::CreateDefaultConvDimensionNumbers());
+
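+  // lhs_dilation expands the input to [1, 0, 2, 0, 3, 0, 4, 0, 5]; sliding
+  // the filter [10, 1] over it yields {10, 2, 20, 3, 30, 4, 40, 5}.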
+ Array4D<float> expected(1, 1, 1, 8, {10, 2, 20, 3, 30, 4, 40, 5});
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+XLA_TEST_F(ConvolutionVariantsTest, FlatLhsDilation) {
+ ComputationBuilder builder(client_, TestName());
+
+ std::vector<float> input_data(1 * 1 * 3 * 4);
+ std::iota(input_data.begin(), input_data.end(), 1.0);
+ Array4D<float> input_array(1, 1, 3, 4, input_data);
+
+ Array4D<float> filter_array(1, 1, 4, 3, {100, 10, 1, //
+ 200, 20, 2, //
+ 300, 30, 3, //
+ 400, 40, 4});
+ auto input = builder.ConstantR4FromArray4D<float>(input_array);
+ auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+ builder.ConvGeneralDilated(
+ /*lhs=*/input, /*rhs=*/filter, /*window_strides=*/{2, 1},
+ /*padding=*/{{1, 0}, {0, 0}}, /*lhs_dilation=*/{3, 2},
+ /*rhs_dilation=*/{},
+ ComputationBuilder::CreateDefaultConvDimensionNumbers());
+
+ Array4D<float> expected(1, 1, 3, 5, {204, 40, 406, 60, 608, //
+ 1518, 180, 1821, 210, 2124, //
+ 4146, 460, 4651, 510, 5156});
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+XLA_TEST_F(ConvolutionVariantsTest, NegativePaddingOnBothEnds) {
+ ComputationBuilder builder(client_, TestName());
+
+ std::vector<float> input_data(1 * 1 * 1 * 5);
+ std::iota(input_data.begin(), input_data.end(), 1.0);
+ Array4D<float> input_array(1, 1, 1, 5, input_data);
+
+ Array4D<float> filter_array(1, 1, 1, 2, {10, 1});
+ auto input = builder.ConstantR4FromArray4D<float>(input_array);
+ auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+ builder.ConvGeneral(
+ /*lhs=*/input, /*rhs=*/filter, /*window_strides=*/{},
+ /*padding=*/{{0, 0}, {-1, -1}},
+ ComputationBuilder::CreateDefaultConvDimensionNumbers());
+
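+  // The padding of {-1, -1} trims the input to [2, 3, 4], so the filter
+  // [10, 1] produces {23, 34}.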
+ Array4D<float> expected(1, 1, 1, 2, {23, 34});
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+XLA_TEST_F(ConvolutionVariantsTest, NegativePaddingLowAndPositivePaddingHigh) {
+ ComputationBuilder builder(client_, TestName());
+
+ std::vector<float> input_data(1 * 1 * 1 * 5);
+ std::iota(input_data.begin(), input_data.end(), 1.0);
+ Array4D<float> input_array(1, 1, 1, 5, input_data);
+
+ Array4D<float> filter_array(1, 1, 1, 2, {10, 1});
+ auto input = builder.ConstantR4FromArray4D<float>(input_array);
+ auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+ builder.ConvGeneral(
+ /*lhs=*/input, /*rhs=*/filter, /*window_strides=*/{},
+ /*padding=*/{{0, 0}, {-1, 2}},
+ ComputationBuilder::CreateDefaultConvDimensionNumbers());
+
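+  // The padding of {-1, 2} turns the input into [2, 3, 4, 5, 0, 0], so the
+  // filter [10, 1] produces {23, 34, 45, 50, 0}.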
+ Array4D<float> expected(1, 1, 1, 5, {23, 34, 45, 50, 0});
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+XLA_TEST_F(ConvolutionVariantsTest, PositivePaddingLowAndNegativePaddingHigh) {
+ ComputationBuilder builder(client_, TestName());
+
+ std::vector<float> input_data(1 * 1 * 1 * 5);
+ std::iota(input_data.begin(), input_data.end(), 1.0);
+ Array4D<float> input_array(1, 1, 1, 5, input_data);
+
+ Array4D<float> filter_array(1, 1, 1, 2, {10, 1});
+ auto input = builder.ConstantR4FromArray4D<float>(input_array);
+ auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+ builder.ConvGeneral(
+ /*lhs=*/input, /*rhs=*/filter, /*window_strides=*/{},
+ /*padding=*/{{0, 0}, {2, -1}},
+ ComputationBuilder::CreateDefaultConvDimensionNumbers());
+
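+  // The padding of {2, -1} turns the input into [0, 0, 1, 2, 3, 4], so the
+  // filter [10, 1] produces {0, 1, 12, 23, 34}.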
+ Array4D<float> expected(1, 1, 1, 5, {0, 1, 12, 23, 34});
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+XLA_TEST_F(ConvolutionVariantsTest, PositivePaddingAndDilation) {
+ ComputationBuilder builder(client_, TestName());
+
+ std::vector<float> input_data(1 * 1 * 1 * 5);
+ std::iota(input_data.begin(), input_data.end(), 1.0);
+ Array4D<float> input_array(1, 1, 1, 5, input_data);
+
+ Array4D<float> filter_array(1, 1, 1, 2, {10, 1});
+ auto input = builder.ConstantR4FromArray4D<float>(input_array);
+ auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+ builder.ConvGeneralDilated(
+ /*lhs=*/input, /*rhs=*/filter, /*window_strides=*/{},
+ /*padding=*/{{0, 0}, {3, 2}},
+ /*lhs_dilation=*/{1, 2}, /*rhs_dilation=*/{1, 2},
+ ComputationBuilder::CreateDefaultConvDimensionNumbers());
+
+ // input:
+ // [1, 2, 3, 4, 5] --dilate-> [1, 0, 2, 0, 3, 0, 4, 0, 5]
+ // ---pad---> [0, 0, 0, 1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 0]
+ // filter:
+ // [10, 1] --dilate-> [10, 0, 1]
+ Array4D<float> expected(1, 1, 1, 12,
+ {0, 1, 0, 12, 0, 23, 0, 34, 0, 45, 0, 50});
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+XLA_TEST_F(ConvolutionVariantsTest, NegativePaddingAndDilation) {
+ ComputationBuilder builder(client_, TestName());
+
+ std::vector<float> input_data(1 * 1 * 1 * 5);
+ std::iota(input_data.begin(), input_data.end(), 1.0);
+ Array4D<float> input_array(1, 1, 1, 5, input_data);
+
+ Array4D<float> filter_array(1, 1, 1, 2, {10, 1});
+ auto input = builder.ConstantR4FromArray4D<float>(input_array);
+ auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+ builder.ConvGeneralDilated(
+ /*lhs=*/input, /*rhs=*/filter, /*window_strides=*/{},
+ /*padding=*/{{0, 0}, {-3, -2}},
+ /*lhs_dilation=*/{1, 2}, /*rhs_dilation=*/{1, 2},
+ ComputationBuilder::CreateDefaultConvDimensionNumbers());
+
+ // input:
+ // [1, 2, 3, 4, 5] --dilate-> [1, 0, 2, 0, 3, 0, 4, 0, 5]
+ // ---pad---> [0, 3, 0, 4]
+ // filter:
+ // [10, 1] --dilate-> [10, 0, 1]
+ Array4D<float> expected(1, 1, 1, 2, {0, 34});
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+TEST_F(ConvolutionVariantsTest, RandomData_Input1x1x2x3_Filter2x1x1x2) {
+ constexpr int bs = 1;
+ constexpr int iz = 1;
+ constexpr int oz = 2;
+ constexpr int iy = 2;
+ constexpr int ix = 3;
+ constexpr int ky = 1;
+ constexpr int kx = 2;
+ std::mt19937 rng;
+ std::uniform_real_distribution<float> distribution;
+ std::vector<float> input_data(bs * iz * iy * ix);
+ for (float& f : input_data) {
+ f = distribution(rng);
+ }
+ std::vector<float> kernel_data(oz * iz * ky * kx);
+ for (float& f : kernel_data) {
+ f = distribution(rng);
+ }
+
+ Array4D<float> input_array(bs, iz, iy, ix, input_data);
+ Array4D<float> filter_array(oz, iz, ky, kx, kernel_data);
+
+ ComputationBuilder builder(client_, TestName());
+ auto input = builder.ConstantR4FromArray4D<float>(input_array);
+ auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+ builder.Conv(input, filter, {1, 1}, Padding::kValid);
+
+ std::unique_ptr<Array4D<float>> expected = ReferenceUtil::ConvArray4D(
+ input_array, filter_array, {1, 1}, Padding::kValid);
+
+ ComputeAndCompareR4<float>(&builder, *expected, {}, error_spec_);
+}
+
+TEST_F(ConvolutionVariantsTest, RandomData_Input1x16x1x1_Filter1x16x1x1) {
+ constexpr int bs = 1;
+ constexpr int iz = 16;
+ constexpr int oz = 1;
+ constexpr int iy = 1;
+ constexpr int ix = 1;
+ constexpr int ky = 1;
+ constexpr int kx = 1;
+ std::mt19937 rng;
+ std::uniform_real_distribution<float> distribution;
+ std::vector<float> input_data(bs * iz * iy * ix);
+ for (float& f : input_data) {
+ f = distribution(rng);
+ }
+ std::vector<float> kernel_data(oz * iz * ky * kx);
+ for (float& f : kernel_data) {
+ f = distribution(rng);
+ }
+
+ Array4D<float> input_array(bs, iz, iy, ix, input_data);
+ Array4D<float> filter_array(oz, iz, ky, kx, kernel_data);
+
+ ComputationBuilder builder(client_, TestName());
+ auto input = builder.ConstantR4FromArray4D<float>(input_array);
+ auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+ builder.Conv(input, filter, {1, 1}, Padding::kValid);
+
+ std::unique_ptr<Array4D<float>> expected = ReferenceUtil::ConvArray4D(
+ input_array, filter_array, {1, 1}, Padding::kValid);
+
+ ComputeAndCompareR4<float>(&builder, *expected, {}, error_spec_);
+}
+
+TEST_F(ConvolutionVariantsTest, RandomData_Input16x16x1x1_Filter1x16x1x1) {
+ constexpr int bs = 16;
+ constexpr int iz = 16;
+ constexpr int oz = 1;
+ constexpr int iy = 1;
+ constexpr int ix = 1;
+ constexpr int ky = 1;
+ constexpr int kx = 1;
+ std::mt19937 rng;
+ std::uniform_real_distribution<float> distribution;
+ std::vector<float> input_data(bs * iz * iy * ix);
+ for (float& f : input_data) {
+ f = distribution(rng);
+ }
+ std::vector<float> kernel_data(oz * iz * ky * kx);
+ for (float& f : kernel_data) {
+ f = distribution(rng);
+ }
+
+ Array4D<float> input_array(bs, iz, iy, ix, input_data);
+ Array4D<float> filter_array(oz, iz, ky, kx, kernel_data);
+
+ ComputationBuilder builder(client_, TestName());
+ auto input = builder.ConstantR4FromArray4D<float>(input_array);
+ auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+ builder.Conv(input, filter, {1, 1}, Padding::kValid);
+
+ std::unique_ptr<Array4D<float>> expected = ReferenceUtil::ConvArray4D(
+ input_array, filter_array, {1, 1}, Padding::kValid);
+
+ ComputeAndCompareR4<float>(&builder, *expected, {}, error_spec_);
+}
+
+TEST_F(ConvolutionVariantsTest, RandomData_Input16x16x1x1_Filter16x16x1x1) {
+ constexpr int bs = 16;
+ constexpr int iz = 16;
+ constexpr int oz = 16;
+ constexpr int iy = 1;
+ constexpr int ix = 1;
+ constexpr int ky = 1;
+ constexpr int kx = 1;
+ std::mt19937 rng;
+ std::uniform_real_distribution<float> distribution;
+ std::vector<float> input_data(bs * iz * iy * ix);
+ for (float& f : input_data) {
+ f = distribution(rng);
+ }
+ std::vector<float> kernel_data(oz * iz * ky * kx);
+ for (float& f : kernel_data) {
+ f = distribution(rng);
+ }
+
+ Array4D<float> input_array(bs, iz, iy, ix, input_data);
+ Array4D<float> filter_array(oz, iz, ky, kx, kernel_data);
+
+ ComputationBuilder builder(client_, TestName());
+ auto input = builder.ConstantR4FromArray4D<float>(input_array);
+ auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+ builder.Conv(input, filter, {1, 1}, Padding::kValid);
+
+ std::unique_ptr<Array4D<float>> expected = ReferenceUtil::ConvArray4D(
+ input_array, filter_array, {1, 1}, Padding::kValid);
+
+ ComputeAndCompareR4<float>(&builder, *expected, {}, error_spec_);
+}
+
+TEST_F(ConvolutionVariantsTest, RandomData_Input16x16x16x16_Filter16x16x16x16) {
+ constexpr int bs = 16;
+ constexpr int iz = 16;
+ constexpr int oz = 16;
+ constexpr int iy = 16;
+ constexpr int ix = 16;
+ constexpr int ky = 16;
+ constexpr int kx = 16;
+ std::mt19937 rng;
+ std::uniform_real_distribution<float> distribution;
+ std::vector<float> input_data(bs * iz * iy * ix);
+ for (float& f : input_data) {
+ f = distribution(rng);
+ }
+ std::vector<float> kernel_data(oz * iz * ky * kx);
+ for (float& f : kernel_data) {
+ f = distribution(rng);
+ }
+
+ Array4D<float> input_array(bs, iz, iy, ix, input_data);
+ Array4D<float> filter_array(oz, iz, ky, kx, kernel_data);
+
+ ComputationBuilder builder(client_, TestName());
+ auto input = builder.ConstantR4FromArray4D<float>(input_array);
+ auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+ builder.Conv(input, filter, {1, 1}, Padding::kValid);
+
+ std::unique_ptr<Array4D<float>> expected = ReferenceUtil::ConvArray4D(
+ input_array, filter_array, {1, 1}, Padding::kValid);
+
+ ComputeAndCompareR4<float>(&builder, *expected, {}, error_spec_);
+}
+
+TEST_F(ConvolutionVariantsTest, Filter1x2x1x1Input1x2x3x1GeneralPadding) {
+ ComputationBuilder builder(client_, TestName());
+
+ std::vector<float> input_data(1 * 2 * 3 * 1);
+ std::iota(input_data.begin(), input_data.end(), 1.0);
+ Array4D<float> input_array(1, 2, 3, 1, input_data);
+ auto input = builder.ConstantR4FromArray4D<float>(input_array);
+
+ std::vector<float> filter_data(1 * 2 * 1 * 1);
+ std::iota(filter_data.begin(), filter_data.end(), 1.0);
+ Array4D<float> filter_array(1, 2, 1, 1, filter_data);
+ auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+
+ ConvolutionDimensionNumbers dnums;
+ // NHWC input format.
+ dnums.set_batch_dimension(0);
+ dnums.add_spatial_dimensions(1);
+ dnums.add_spatial_dimensions(2);
+ dnums.set_feature_dimension(3);
+
+  // TensorFlow filter shape: [ H, W, inC, outC ]
+ dnums.add_kernel_spatial_dimensions(0);
+ dnums.add_kernel_spatial_dimensions(1);
+ dnums.set_kernel_input_feature_dimension(2);
+ dnums.set_kernel_output_feature_dimension(3);
+
+  // Tests padding sizes that don't correspond to either SAME or VALID
+  // padding.
+ builder.ConvGeneral(input, filter, {1, 1}, {{2, 1}, {2, 3}}, dnums);
+
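+  // Output spatial dims: H = 2 + (2 + 1) - (1 - 1) = 5 and
+  // W = 3 + (2 + 3) - (2 - 1) = 7.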
+ std::vector<float> expected_data = {
+ 0, 0, 0, 0, 0, 0, 0, //
+ 0, 0, 0, 0, 0, 0, 0, //
+ 0, 2, 5, 8, 3, 0, 0, //
+ 0, 8, 14, 17, 6, 0, 0, //
+ 0, 0, 0, 0, 0, 0, 0 //
+ };
+ Array4D<float> expected(1, 5, 7, 1, expected_data);
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+TEST_F(ConvolutionVariantsTest, Filter1x1x1x1Input1x2x3x1GeneralPadding) {
+ ComputationBuilder builder(client_, TestName());
+
+ std::vector<float> input_data(1 * 2 * 3 * 1);
+ std::iota(input_data.begin(), input_data.end(), 1.0);
+ Array4D<float> input_array(1, 2, 3, 1, input_data);
+ auto input = builder.ConstantR4FromArray4D<float>(input_array);
+
+ std::vector<float> filter_data(1 * 1 * 1 * 1);
+ std::iota(filter_data.begin(), filter_data.end(), 2.0);
+ Array4D<float> filter_array(1, 1, 1, 1, filter_data);
+ auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+
+ ConvolutionDimensionNumbers dnums;
+ // NHWC input format.
+ dnums.set_batch_dimension(0);
+ dnums.add_spatial_dimensions(1);
+ dnums.add_spatial_dimensions(2);
+ dnums.set_feature_dimension(3);
+
+  // TensorFlow filter shape: [ H, W, inC, outC ]
+ dnums.add_kernel_spatial_dimensions(0);
+ dnums.add_kernel_spatial_dimensions(1);
+ dnums.set_kernel_input_feature_dimension(2);
+ dnums.set_kernel_output_feature_dimension(3);
+
+  // Tests padding sizes that don't correspond to either SAME or VALID
+  // padding.
+ builder.ConvGeneral(input, filter, {1, 1}, {{2, 1}, {2, 3}}, dnums);
+
+ std::vector<float> expected_data = {
+ 0, 0, 0, 0, 0, 0, 0, 0, //
+ 0, 0, 0, 0, 0, 0, 0, 0, //
+ 0, 0, 2, 4, 6, 0, 0, 0, //
+ 0, 0, 8, 10, 12, 0, 0, 0, //
+ 0, 0, 0, 0, 0, 0, 0, 0 //
+ };
+ Array4D<float> expected(1, 5, 8, 1, expected_data);
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+TEST_F(ConvolutionVariantsTest, Filter1x1x1x1Input1x2x3x1NoPadding) {
+ ComputationBuilder builder(client_, TestName());
+
+ std::vector<float> input_data(1 * 2 * 3 * 1);
+ std::iota(input_data.begin(), input_data.end(), 1.0);
+ Array4D<float> input_array(1, 2, 3, 1, input_data);
+ auto input = builder.ConstantR4FromArray4D<float>(input_array);
+
+ std::vector<float> filter_data(1 * 1 * 1 * 1);
+ std::iota(filter_data.begin(), filter_data.end(), 2.0);
+ Array4D<float> filter_array(1, 1, 1, 1, filter_data);
+ auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+
+ ConvolutionDimensionNumbers dnums;
+ // NHWC input format.
+ dnums.set_batch_dimension(0);
+ dnums.add_spatial_dimensions(1);
+ dnums.add_spatial_dimensions(2);
+ dnums.set_feature_dimension(3);
+
+  // TensorFlow filter shape: [ H, W, inC, outC ]
+ dnums.add_kernel_spatial_dimensions(0);
+ dnums.add_kernel_spatial_dimensions(1);
+ dnums.set_kernel_input_feature_dimension(2);
+ dnums.set_kernel_output_feature_dimension(3);
+
+  // Tests zero padding sizes. This case can be computed as a matmul.
+ builder.ConvGeneral(input, filter, {1, 1}, {{0, 0}, {0, 0}}, dnums);
+
+ std::vector<float> expected_data = {
+ 2, 4, 6, //
+ 8, 10, 12,
+ };
+ Array4D<float> expected(1, 2, 3, 1, expected_data);
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+TEST_F(ConvolutionVariantsTest, Filter1x1x2x3Input1x2x3x2NoPadding) {
+ ComputationBuilder builder(client_, TestName());
+
+ std::vector<float> input_data(1 * 2 * 3 * 2);
+ std::iota(input_data.begin(), input_data.end(), 1.0);
+ Array4D<float> input_array(1, 2, 3, 2, input_data);
+ auto input = builder.ConstantR4FromArray4D<float>(input_array);
+
+ std::vector<float> filter_data(1 * 1 * 2 * 3);
+ std::iota(filter_data.begin(), filter_data.end(), 2.0);
+ Array4D<float> filter_array(1, 1, 2, 3, filter_data);
+ auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+
+ ConvolutionDimensionNumbers dnums;
+ // NHWC input format.
+ dnums.set_batch_dimension(0);
+ dnums.add_spatial_dimensions(1);
+ dnums.add_spatial_dimensions(2);
+ dnums.set_feature_dimension(3);
+
+  // TensorFlow filter shape: [ H, W, inC, outC ]
+ dnums.add_kernel_spatial_dimensions(0);
+ dnums.add_kernel_spatial_dimensions(1);
+ dnums.set_kernel_input_feature_dimension(2);
+ dnums.set_kernel_output_feature_dimension(3);
+
+  // Tests zero padding sizes. This case can be computed as a matmul.
+ builder.ConvGeneral(input, filter, {1, 1}, {{0, 0}, {0, 0}}, dnums);
+
+ std::vector<float> expected_data = {
+ 12, 15, 18, //
+ 26, 33, 40, //
+ 40, 51, 62, //
+ 54, 69, 84, //
+ 68, 87, 106, //
+ 82, 105, 128, //
+ };
+ Array4D<float> expected(1, 2, 3, 3, expected_data);
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+// Regression test for b/32034796.
+//
+// XLA:GPU fuses
+// Conv([1,2,3], Reverse([5,6]), padding_low=1)
+// into
+// BackwardInputConv([1,2,3], [5,6], padding_low=0, padding_high=1)
+TEST_F(ConvolutionVariantsTest, BackwardInputLowPaddingLessThanHighPadding) {
+ ComputationBuilder builder(client_, TestName());
+
+ auto gradients = builder.ConstantR4FromArray4D<float>(
+ Array4D<float>(1, 1, 1, 3, /*values=*/{1, 2, 3}));
+ auto weights = builder.ConstantR4FromArray4D<float>(
+ Array4D<float>(1, 1, 1, 2, /*values=*/{5, 6}));
+ auto mirrored_weights = builder.Rev(weights, {2, 3});
+ builder.ConvWithGeneralPadding(gradients, mirrored_weights,
+ /*window_strides=*/{1, 1},
+ /*padding=*/{{0, 0}, {1, 0}});
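+  // The padded gradients are [0, 1, 2, 3] and the reversed weights are
+  // [6, 5], so the outputs are 0*6+1*5 = 5, 1*6+2*5 = 16, and 2*6+3*5 = 27.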
+ ComputeAndCompareR4<float>(&builder, {{{{5, 16, 27}}}}, {}, error_spec_);
+}
+
+// XLA:GPU fuses
+// Conv([1], Reverse([1,10,100]), padding_high=3, base_dilation=3)
+// into
+// BackwardInputConv([1], [1,10,100], stride=3, padding=(2,1))
+TEST_F(ConvolutionVariantsTest, BackwardInputLowPaddingGreaterThanHighPadding) {
+ ComputationBuilder builder(client_, TestName());
+
+ auto gradients = builder.ConstantR4FromArray4D<float>(
+ Array4D<float>(1, 1, 1, 1, /*values=*/{1}));
+ auto weights = builder.ConstantR4FromArray4D<float>(
+ Array4D<float>(1, 1, 1, 3, /*values=*/{1, 10, 100}));
+ auto mirrored_weights = builder.Rev(weights, {2, 3});
+ builder.ConvGeneralDilated(
+ gradients, mirrored_weights,
+ /*window_strides=*/{1, 1},
+ /*padding=*/{{0, 0}, {0, 3}},
+ /*lhs_dilation=*/{1, 3}, /*rhs_dilation=*/{},
+ ComputationBuilder::CreateDefaultConvDimensionNumbers());
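+  // The base-dilated, padded gradients are [1, 0, 0, 0] and the reversed
+  // weights are [100, 10, 1], so the outputs are {100, 0}.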
+ ComputeAndCompareR4<float>(&builder, {{{{100, 0}}}}, {}, error_spec_);
+}
+
+// XLA:GPU fuses
+// Conv([1], Reverse([1,10,100]), padding=(1,1))
+// into
+// BackwardInputConv([1], [1,10,100], padding=(1,1))
+TEST_F(ConvolutionVariantsTest, BackwardInputEvenPadding) {
+ ComputationBuilder builder(client_, TestName());
+
+ auto gradients = builder.ConstantR4FromArray4D<float>(
+ Array4D<float>(1, 1, 1, 1, /*values=*/{1}));
+ auto weights = builder.ConstantR4FromArray4D<float>(
+ Array4D<float>(1, 1, 1, 3, /*values=*/{1, 10, 100}));
+ auto mirrored_weights = builder.Rev(weights, {2, 3});
+ builder.ConvWithGeneralPadding(gradients, mirrored_weights,
+ /*window_strides=*/{1, 1},
+ /*padding=*/{{0, 0}, {1, 1}});
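+  // The padded gradients are [0, 1, 0] and the reversed weights are
+  // [100, 10, 1], so the single output is 10.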
+ ComputeAndCompareR4<float>(&builder, {{{{10}}}}, {}, error_spec_);
+}
+
+// HLO pattern
+//   Conv([1,2,3], Reverse([1,10]), padding_high=2)
+// could be fused to
+// BackwardInputConv([1,2,3], [1,10], padding_low=1, padding_high=-1)
+//
+// However, XLA:GPU doesn't actually fuse it because PadInsertion doesn't
+// support negative padding on backward convolution yet (b/32744257).
+TEST_F(ConvolutionVariantsTest, BackwardInputWithNegativePaddingHigh) {
+ ComputationBuilder builder(client_, TestName());
+
+ auto gradients = builder.ConstantR4FromArray4D<float>(
+ Array4D<float>(1, 1, 1, 3, /*values=*/{1, 2, 3}));
+ auto weights = builder.ConstantR4FromArray4D<float>(
+ Array4D<float>(1, 1, 1, 2, /*values=*/{1, 10}));
+ auto mirrored_weights = builder.Rev(weights, {2, 3});
+ builder.ConvWithGeneralPadding(gradients, mirrored_weights,
+ /*window_strides=*/{1, 1},
+ /*padding=*/{{0, 0}, {0, 2}});
+
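+  // The padded gradients are [1, 2, 3, 0, 0] and the reversed weights are
+  // [10, 1], so the outputs are {12, 23, 30, 0}.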
+ ComputeAndCompareR4<float>(&builder, {{{{12, 23, 30, 0}}}}, {}, error_spec_);
+}
+
+TEST_F(ConvolutionVariantsTest, BackwardFilterLowPaddingLessThanHighPadding) {
+ ComputationBuilder builder(client_, TestName());
+
+ // activations: 1,2,3,4 ---pad--> 0,1,2,3,4,0,0
+ // gradients: 100,10,1 -dilate-> 100,0,10,0,1
+ // weight gradients: 24,130,240
+ //
+ // This pattern will be fused to backward convolution with padding=(1,2).
+ auto activations = builder.ConstantR4FromArray4D<float>(
+ Array4D<float>(1, 1, 1, 4, /*values=*/{1, 2, 3, 4}));
+ auto gradients = builder.ConstantR4FromArray4D<float>(
+ Array4D<float>(1, 1, 1, 3, /*values=*/{100, 10, 1}));
+ auto forward_conv = builder.ConvGeneralDilated(
+ activations, gradients,
+ /*window_strides=*/{1, 1},
+ /*padding=*/{{0, 0}, {1, 2}},
+ /*lhs_dilation=*/{}, /*rhs_dilation=*/{1, 2},
+ ComputationBuilder::CreateDefaultConvDimensionNumbers());
+ builder.Transpose(forward_conv, {0, 1, 2, 3});
+
+ ComputeAndCompareR4<float>(&builder, {{{{24, 130, 240}}}}, {}, error_spec_);
+}
+
+TEST_F(ConvolutionVariantsTest,
+ BackwardFilterLowPaddingGreaterThanHighPadding) {
+ ComputationBuilder builder(client_, TestName());
+
+ // activations: 1,2,3,4 ---pad--> 0,0,1,2,3,4
+ // gradients: 100,10,1 -dilate-> 100,0,10,0,1
+ // weight gradients: 13,24
+ //
+ // This pattern will be fused to backward convolution with padding=(2,1).
+ // Note: both (2,1) and (2,0) are valid padding for the backward convolution
+ // because the stride is 2.
+ auto activations = builder.ConstantR4FromArray4D<float>(
+ Array4D<float>(1, 1, 1, 4, /*values=*/{1, 2, 3, 4}));
+ auto gradients = builder.ConstantR4FromArray4D<float>(
+ Array4D<float>(1, 1, 1, 3, /*values=*/{100, 10, 1}));
+ auto forward_conv = builder.ConvGeneralDilated(
+ activations, gradients,
+ /*window_strides=*/{1, 1},
+ /*padding=*/{{0, 0}, {2, 0}},
+ /*lhs_dilation=*/{}, /*rhs_dilation=*/{1, 2},
+ ComputationBuilder::CreateDefaultConvDimensionNumbers());
+ builder.Transpose(forward_conv, {0, 1, 2, 3});
+
+ ComputeAndCompareR4<float>(&builder, {{{{13, 24}}}}, {}, error_spec_);
+}
+
+TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding) {
+ ComputationBuilder builder(client_, TestName());
+
+ // activations: 1,2,3,4 ---pad--> 0,0,1,2,3,4,0
+ // gradients: 100,10,1 -dilate-> 100,0,10,0,1
+ // weight gradients: 13,24,130
+ //
+ // This pattern will be fused to backward convolution with padding=(2,2).
+ // Note: both (2,1) and (2,2) are valid padding for the backward convolution
+  // because the stride is 2. ConvolutionFolding prefers (2,2) because cuDNN
+  // only supports even padding; using (2,1) would require extra
+  // canonicalization effort.
+ auto activations = builder.ConstantR4FromArray4D<float>(
+ Array4D<float>(1, 1, 1, 4, /*values=*/{1, 2, 3, 4}));
+ auto gradients = builder.ConstantR4FromArray4D<float>(
+ Array4D<float>(1, 1, 1, 3, /*values=*/{100, 10, 1}));
+ auto forward_conv = builder.ConvGeneralDilated(
+ activations, gradients,
+ /*window_strides=*/{1, 1},
+ /*padding=*/{{0, 0}, {2, 1}},
+ /*lhs_dilation=*/{}, /*rhs_dilation=*/{1, 2},
+ ComputationBuilder::CreateDefaultConvDimensionNumbers());
+ builder.Transpose(forward_conv, {0, 1, 2, 3});
+
+ ComputeAndCompareR4<float>(&builder, {{{{13, 24, 130}}}}, {}, error_spec_);
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/copy_test.cc b/tensorflow/compiler/xla/tests/copy_test.cc
new file mode 100644
index 0000000000..29e2950533
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/copy_test.cc
@@ -0,0 +1,277 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+#include <utility>
+
+#include "tensorflow/compiler/xla/array2d.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+class CopyOpTest : public HloTestBase {
+ protected:
+ void TestCopyOp(const Literal& literal) {
+ auto builder = HloComputation::Builder(TestName());
+ auto constant = builder.AddInstruction(
+ HloInstruction::CreateConstant(MakeUnique<Literal>(literal)));
+ builder.AddInstruction(HloInstruction::CreateUnary(
+ constant->shape(), HloOpcode::kCopy, constant));
+ auto computation = builder.Build();
+ auto hlo_module = MakeUnique<HloModule>("test_module");
+ hlo_module->AddEntryComputation(std::move(computation));
+
+ std::unique_ptr<Literal> result =
+ ExecuteAndTransfer(std::move(hlo_module), {});
+ LiteralTestUtil::ExpectEqual(literal, *result);
+ }
+
+ void TestCopyConstantLayout021(size_t n1, size_t n2, size_t n3);
+ void TestCopyConstantLayoutR4(size_t n1, size_t n2, size_t n3, size_t n4,
+ tensorflow::gtl::ArraySlice<int64> permutation);
+};
+
+TEST_F(CopyOpTest, CopyR0Bool) {
+ TestCopyOp(*LiteralUtil::CreateR0<bool>(true));
+}
+
+TEST_F(CopyOpTest, CopyR1S0U32) {
+ TestCopyOp(*LiteralUtil::CreateR1<uint32>({}));
+}
+
+TEST_F(CopyOpTest, CopyR1S3U32) {
+ TestCopyOp(*LiteralUtil::CreateR1<uint32>({1, 2, 3}));
+}
+
+TEST_F(CopyOpTest, CopyR3F32_2x2x3) {
+ TestCopyOp(
+ *LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
+ {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}));
+}
+
+TEST_F(CopyOpTest, CopyR4S32_2x2x3x2) {
+ TestCopyOp(*LiteralUtil::CreateR4(
+ {{{{1, -2}, {-4, 5}, {6, 7}}, {{8, 9}, {10, 11}, {12, 13}}},
+ {{{10, 3}, {7, -2}, {3, 6}}, {{2, 5}, {-11, 5}, {-2, -5}}}}));
+}
+
+TEST_F(CopyOpTest, CopyR4S32_0x2x3x2) {
+ TestCopyOp(*LiteralUtil::CreateR4FromArray4D(Array4D<int32>(0, 2, 3, 2)));
+}
+
+TEST_F(CopyOpTest, CopyParameterScalar) {
+ auto builder = HloComputation::Builder(TestName());
+
+ // Copy literal to device to use as parameter.
+ auto literal = LiteralUtil::CreateR0<float>(42.0);
+ Shape shape = literal->shape();
+ auto constant_device_base = TransferToDevice(*literal);
+
+ auto param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, shape, "param0"));
+ builder.AddInstruction(
+ HloInstruction::CreateUnary(shape, HloOpcode::kCopy, param0));
+
+ auto computation = builder.Build();
+
+ auto hlo_module = MakeUnique<HloModule>("test_module");
+ hlo_module->AddEntryComputation(std::move(computation));
+
+ std::unique_ptr<Literal> result =
+ ExecuteAndTransfer(std::move(hlo_module), {constant_device_base});
+ LiteralTestUtil::ExpectR0Near<float>(42.0f, *result, error_spec_);
+}
+
+TEST_F(CopyOpTest, CopyConstantR2Twice) {
+ auto builder = HloComputation::Builder(TestName());
+
+ auto literal = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+ auto constant = builder.AddInstruction(
+ HloInstruction::CreateConstant(std::move(literal)));
+
+ auto copy = builder.AddInstruction(HloInstruction::CreateUnary(
+ constant->shape(), HloOpcode::kCopy, constant));
+ builder.AddInstruction(
+ HloInstruction::CreateUnary(copy->shape(), HloOpcode::kCopy, copy));
+
+ auto computation = builder.Build();
+
+ auto hlo_module = MakeUnique<HloModule>("test_module");
+ hlo_module->AddEntryComputation(std::move(computation));
+ std::unique_ptr<Literal> result =
+ ExecuteAndTransfer(std::move(hlo_module), {});
+ LiteralTestUtil::ExpectR2Near<float>({{1.0, 2.0}, {3.0, 4.0}}, *result,
+ error_spec_);
+}
+
+TEST_F(CopyOpTest, CopyConstantR2DifferentLayouts) {
+ HloComputation::Builder builder(TestName());
+
+ std::unique_ptr<Literal> literal =
+ LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+ // Reverse the minor-to-major order of the literal.
+ Layout* literal_layout = literal->mutable_shape()->mutable_layout();
+ ASSERT_EQ(2, literal_layout->minor_to_major_size());
+ literal_layout->mutable_minor_to_major()->SwapElements(0, 1);
+
+ HloInstruction* constant = builder.AddInstruction(
+ HloInstruction::CreateConstant(std::move(literal)));
+
+ builder.AddInstruction(HloInstruction::CreateUnary(
+ constant->shape(), HloOpcode::kCopy, constant));
+
+ std::unique_ptr<HloComputation> computation = builder.Build();
+
+ auto hlo_module = MakeUnique<HloModule>("test_module");
+ hlo_module->AddEntryComputation(std::move(computation));
+ std::unique_ptr<Literal> result =
+ ExecuteAndTransfer(std::move(hlo_module), {});
+
+ // The result of the computation has the default layout, which is the inverse
+ // of the layout of the source literal.
+ LiteralTestUtil::ExpectR2Near<float>({{1.0, 3.0}, {2.0, 4.0}}, *result,
+ error_spec_);
+}
+
+void CopyOpTest::TestCopyConstantLayout021(size_t n1, size_t n2, size_t n3) {
+ Array3D<int32> a(n1, n2, n3);
+ for (size_t i = 0; i < n1; ++i) {
+ for (size_t j = 0; j < n2; ++j) {
+ for (size_t k = 0; k < n3; ++k) {
+ a(i, j, k) = i * n3 * n2 + j * n3 + k;
+ }
+ }
+ }
+
+ HloComputation::Builder builder(TestName());
+
+ std::unique_ptr<Literal> literal = LiteralUtil::CreateR3FromArray3D(a);
+
+ HloInstruction* constant = builder.AddInstruction(
+ HloInstruction::CreateConstant(std::move(literal)));
+
+ builder.AddInstruction(HloInstruction::CreateUnary(
+ constant->shape(), HloOpcode::kCopy, constant));
+
+ std::unique_ptr<HloComputation> computation = builder.Build();
+
+ auto hlo_module = MakeUnique<HloModule>("test_module");
+ auto config = MakeUnique<HloModuleConfig>(computation->ComputeProgramShape());
+ *config->mutable_entry_computation_layout()->mutable_result_layout() =
+ ShapeLayout(ShapeUtil::MakeShapeWithLayout(
+ constant->shape().element_type(),
+ AsInt64Slice(constant->shape().dimensions()), {1, 2, 0}));
+ hlo_module->AddEntryComputation(std::move(computation));
+ std::unique_ptr<Literal> result =
+ ExecuteAndTransfer(std::move(hlo_module), std::move(config), {});
+
+ LiteralTestUtil::ExpectR3EqualArray3D(a, *result);
+}
+
+void CopyOpTest::TestCopyConstantLayoutR4(
+ size_t n1, size_t n2, size_t n3, size_t n4,
+ tensorflow::gtl::ArraySlice<int64> permutation) {
+ Array4D<int32> a(n1, n2, n3, n4);
+ for (size_t i = 0; i < n1; ++i) {
+ for (size_t j = 0; j < n2; ++j) {
+ for (size_t k = 0; k < n3; ++k) {
+ for (size_t l = 0; l < n4; ++l) {
+ a(i, j, k, l) = i * n4 * n3 * n2 + j * n4 * n3 + k * n4 + l;
+ }
+ }
+ }
+ }
+
+ HloComputation::Builder builder(TestName());
+
+ std::unique_ptr<Literal> literal = LiteralUtil::CreateR4FromArray4D(a);
+
+ HloInstruction* constant = builder.AddInstruction(
+ HloInstruction::CreateConstant(std::move(literal)));
+
+ builder.AddInstruction(HloInstruction::CreateUnary(
+ constant->shape(), HloOpcode::kCopy, constant));
+
+ std::unique_ptr<HloComputation> computation = builder.Build();
+
+ auto hlo_module = MakeUnique<HloModule>("test_module");
+ auto config = MakeUnique<HloModuleConfig>(computation->ComputeProgramShape());
+ *config->mutable_entry_computation_layout()->mutable_result_layout() =
+ ShapeLayout(ShapeUtil::MakeShapeWithLayout(
+ constant->shape().element_type(),
+ AsInt64Slice(constant->shape().dimensions()), ({
+ std::vector<int64> p(permutation.rbegin(), permutation.rend());
+ p;
+ })));
+ hlo_module->AddEntryComputation(std::move(computation));
+ std::unique_ptr<Literal> result =
+ ExecuteAndTransfer(std::move(hlo_module), std::move(config), {});
+
+ LiteralTestUtil::ExpectR4EqualArray4D(a, *result);
+}
+
+XLA_TEST_F(CopyOpTest, CopyConstantR3Layout021_SingleIncompleteTilePerLayer) {
+ TestCopyConstantLayout021(2, 2, 3);
+}
+
+XLA_TEST_F(CopyOpTest, CopyConstantR3Layout021_SingleCompleteTilePerLayer) {
+ TestCopyConstantLayout021(2, 32, 32);
+}
+
+XLA_TEST_F(CopyOpTest, CopyConstantR3Layout021_MultipleTilesPerLayer) {
+ TestCopyConstantLayout021(2, 70, 35);
+}
+
+XLA_TEST_F(CopyOpTest, CopyConstantR4Layout0231_MultipleTilesPerLayer) {
+ TestCopyConstantLayoutR4(2, 70, 7, 5, {0, 2, 3, 1});
+}
+
+XLA_TEST_F(CopyOpTest, CopyConstantR4Layout0312_MultipleTilesPerLayer) {
+ TestCopyConstantLayoutR4(2, 14, 5, 35, {0, 3, 1, 2});
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/custom_call_test.cc b/tensorflow/compiler/xla/tests/custom_call_test.cc
new file mode 100644
index 0000000000..dc54c9defe
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/custom_call_test.cc
@@ -0,0 +1,148 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+#include <utility>
+
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/dynamic_annotations.h"
+#include "tensorflow/core/platform/test.h"
+
+extern "C" void __attribute__((visibility("default")))
+R0F32Add2(float* out, float** in) {
+ TF_ANNOTATE_MEMORY_IS_INITIALIZED(in, sizeof(float*));
+ *out = **in + 2.0f;
+}
+
+extern "C" void __attribute__((visibility("default")))
+R2F32ReduceSum(float* out, float** in) {
+ TF_ANNOTATE_MEMORY_IS_INITIALIZED(in, sizeof(float) * 4);
+ float* array = in[0];
+ *out = array[0] + array[1] + array[2] + array[3];
+}
+
+extern "C" void __attribute__((visibility("default")))
+Add1ToValues(float* out, float** in) {
+ TF_ANNOTATE_MEMORY_IS_INITIALIZED(in, sizeof(float) * 4);
+ float* array = in[0];
+ out[0] = array[0] + 1;
+ out[1] = array[1] + 1;
+ out[2] = array[2] + 1;
+ out[3] = array[3] + 1;
+}
+
+namespace xla {
+namespace {
+
+class CustomCallTest : public HloTestBase {
+ protected:
+ Shape r0f32_ = ShapeUtil::MakeShape(F32, {});
+ Shape r2f32_ = ShapeUtil::MakeShape(F32, {2, 2});
+};
+
+XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR0F32Add2)) {
+ auto hlo_module = MakeUnique<HloModule>("test_module");
+ auto builder = HloComputation::Builder(TestName());
+
+ auto constant = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
+ builder.AddInstruction(
+ HloInstruction::CreateCustomCall(r0f32_, {constant}, "R0F32Add2"));
+
+ hlo_module->AddEntryComputation(builder.Build());
+
+ std::unique_ptr<Literal> result =
+ ExecuteAndTransfer(std::move(hlo_module), {});
+ LiteralTestUtil::ExpectR0Near<float>(44.0f, *result, error_spec_);
+}
+
+XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR2F32Reduce)) {
+ auto hlo_module = MakeUnique<HloModule>("test_module");
+ auto builder = HloComputation::Builder(TestName());
+
+ Array2D<float> array(2, 2);
+ array(0, 0) = 1.0f;
+ array(0, 1) = 2.0f;
+ array(1, 0) = 3.0f;
+ array(1, 1) = 4.0f;
+
+ auto constant = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR2FromArray2D(array)));
+ builder.AddInstruction(
+ HloInstruction::CreateCustomCall(r0f32_, {constant}, "R2F32ReduceSum"));
+
+ hlo_module->AddEntryComputation(builder.Build());
+
+ std::unique_ptr<Literal> result =
+ ExecuteAndTransfer(std::move(hlo_module), {});
+ LiteralTestUtil::ExpectR0Near<float>(10.0f, *result, error_spec_);
+}
+
+XLA_TEST_F(CustomCallTest,
+ DISABLED_ON_GPU(CustomCall_UsedInOtherComputations)) {
+ auto hlo_module = MakeUnique<HloModule>("test_module");
+ auto b = HloComputation::Builder(TestName());
+
+ auto input = b.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR2FromArray2D(
+ Array2D<float>{{1.0f, 2.0f}, {3.0f, 4.0f}})));
+ auto incremented = b.AddInstruction(HloInstruction::CreateCustomCall(
+ ShapeUtil::MakeShape(F32, {1, 2, 2}), {input}, "Add1ToValues"));
+ auto incremented_again = b.AddInstruction(HloInstruction::CreateCustomCall(
+ ShapeUtil::MakeShape(F32, {1, 2, 2}), {incremented}, "Add1ToValues"));
+
+ // Concatenate the values along first dim.
+ b.AddInstruction(
+ HloInstruction::CreateConcatenate(ShapeUtil::MakeShape(F32, {2, 2, 2}),
+ {incremented, incremented_again}, 0));
+
+ hlo_module->AddEntryComputation(b.Build());
+
+ std::unique_ptr<Literal> result =
+ ExecuteAndTransfer(std::move(hlo_module), {});
+ LiteralTestUtil::ExpectR3EqualArray3D<float>(
+ Array3D<float>{{{2, 3}, {4, 5}}, {{3, 4}, {5, 6}}}, *result);
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/deallocation_test.cc b/tensorflow/compiler/xla/tests/deallocation_test.cc
new file mode 100644
index 0000000000..528efd2942
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/deallocation_test.cc
@@ -0,0 +1,155 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+
+#include "tensorflow/compiler/xla/client/computation.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/global_data.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace {
+
+class DeallocationTest : public ClientLibraryTestBase {
+ protected:
+ // Builds and executes the given computation, then verifies that the result
+ // can be transferred back from the device successfully.
+ std::unique_ptr<GlobalData> ExecuteAndCheckTransfer(
+ ComputationBuilder* builder,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
+ Computation computation = builder->Build().ConsumeValueOrDie();
+ auto global_data =
+ client_->Execute(computation, arguments).ConsumeValueOrDie();
+ TF_CHECK_OK(client_->Transfer(*global_data).status());
+ return global_data;
+ }
+};
+
+TEST_F(DeallocationTest, DeallocateScalar) {
+ ComputationBuilder builder(client_, TestName());
+ builder.ConstantR0<float>(42.0);
+ auto global_data = ExecuteAndCheckTransfer(&builder, {});
+
+ // A result can be transferred an arbitrary number of times. Add an extra
+ // transfer here so we're not just testing that a second call to Transfer
+ // fails.
+ ASSERT_IS_OK(client_->Transfer(*global_data).status());
+
+ ASSERT_IS_OK(client_->Unregister(*global_data));
+
+ auto transfer_status = client_->Transfer(*global_data);
+ ASSERT_FALSE(transfer_status.ok());
+ ASSERT_MATCH(transfer_status.status().error_message(),
+ testing::HasSubstr("was previously deallocated"));
+}
+
+TEST_F(DeallocationTest, DeallocateVector) {
+ ComputationBuilder builder(client_, TestName());
+ builder.ConstantR1<float>({1.0, 2.0, 3.0, 4.0});
+ auto global_data = ExecuteAndCheckTransfer(&builder, {});
+
+ ASSERT_IS_OK(client_->Unregister(*global_data));
+
+ auto transfer_status = client_->Transfer(*global_data);
+ ASSERT_FALSE(transfer_status.ok());
+ ASSERT_MATCH(transfer_status.status().error_message(),
+ testing::HasSubstr("was previously deallocated"));
+}
+
+TEST_F(DeallocationTest, DeallocateEmptyVector) {
+ ComputationBuilder builder(client_, TestName());
+ builder.ConstantR1<float>({});
+ auto global_data = ExecuteAndCheckTransfer(&builder, {});
+
+ ASSERT_IS_OK(client_->Unregister(*global_data));
+
+ auto transfer_status = client_->Transfer(*global_data);
+ ASSERT_FALSE(transfer_status.ok());
+ ASSERT_MATCH(transfer_status.status().error_message(),
+ testing::HasSubstr("was previously deallocated"));
+}
+
+XLA_TEST_F(DeallocationTest, DeallocateTuple) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Tuple({builder.ConstantR0<float>(42.0),
+ builder.ConstantR1<float>({1.0, 2.0, 3.0})});
+ auto global_data = ExecuteAndCheckTransfer(&builder, {});
+
+ ASSERT_IS_OK(client_->Unregister(*global_data));
+
+ auto transfer_status = client_->Transfer(*global_data);
+ ASSERT_FALSE(transfer_status.ok());
+ ASSERT_MATCH(transfer_status.status().error_message(),
+ testing::HasSubstr("was previously deallocated"));
+}
+
+XLA_TEST_F(DeallocationTest, DeallocateTupleWithRepeatedElements) {
+ ComputationBuilder builder(client_, TestName());
+ auto element = builder.ConstantR0<float>(42.0);
+ auto inner_tuple = builder.Tuple({builder.ConstantR0<float>(42.0), element});
+ builder.Tuple({element, inner_tuple, element});
+ auto global_data = ExecuteAndCheckTransfer(&builder, {});
+
+ ASSERT_IS_OK(client_->Unregister(*global_data));
+
+ auto transfer_status = client_->Transfer(*global_data);
+ ASSERT_FALSE(transfer_status.ok());
+ ASSERT_MATCH(transfer_status.status().error_message(),
+ testing::HasSubstr("was previously deallocated"));
+}
+
+XLA_TEST_F(DeallocationTest, DeallocateNestedTuple) {
+ ComputationBuilder builder(client_, TestName());
+ auto inner_tuple =
+ builder.Tuple({builder.ConstantR0<float>(42.0),
+ builder.ConstantR1<float>({1.0, 2.0, 3.0})});
+ builder.Tuple({inner_tuple, builder.ConstantR1<float>({0.123, 0.456})});
+ auto global_data = ExecuteAndCheckTransfer(&builder, {});
+
+ ASSERT_IS_OK(client_->Unregister(*global_data));
+
+ auto transfer_status = client_->Transfer(*global_data);
+ ASSERT_FALSE(transfer_status.ok());
+ ASSERT_MATCH(transfer_status.status().error_message(),
+ testing::HasSubstr("was previously deallocated"));
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc b/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc
new file mode 100644
index 0000000000..57a7c61b14
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc
@@ -0,0 +1,215 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+#include <vector>
+
+#include "tensorflow/compiler/xla/client/computation.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/global_data.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace {
+
+class DeconstructTupleTest : public ClientLibraryTestBase {
+ protected:
+ // Builds and executes the given computation, then verifies that the result
+ // can be transferred back from the device successfully.
+ std::unique_ptr<GlobalData> ExecuteAndCheckTransfer(
+ ComputationBuilder* builder,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
+ Computation computation = builder->Build().ConsumeValueOrDie();
+ auto global_data =
+ client_->Execute(computation, arguments).ConsumeValueOrDie();
+ TF_CHECK_OK(client_->Transfer(*global_data).status());
+ return global_data;
+ }
+};
+
+TEST_F(DeconstructTupleTest, DeconstructTuple) {
+ ComputationBuilder builder(client_, TestName());
+ auto const1 = builder.ConstantR1<float>({1.0, 2.0, 3.0, 4.0});
+ auto const2 = builder.ConstantR1<float>({2.0, 4.0, 6.0, 8.0});
+ builder.Tuple({const1, const2});
+ auto global_data = ExecuteAndCheckTransfer(&builder, {});
+
+ auto result_status = client_->DeconstructTuple(*global_data);
+ EXPECT_TRUE(result_status.ok());
+
+ // Try copying the elements back and comparing them.
+ auto handles = result_status.ConsumeValueOrDie();
+ std::vector<float> copy(4);
+ ASSERT_IS_OK(client_->TransferInProcess(*handles[0], &copy[0]));
+ EXPECT_MATCH(copy, testing::VectorMatcher<float>({1.0, 2.0, 3.0, 4.0}));
+ ASSERT_IS_OK(client_->TransferInProcess(*handles[1], &copy[0]));
+ EXPECT_MATCH(copy, testing::VectorMatcher<float>({2.0, 4.0, 6.0, 8.0}));
+}
+
+TEST_F(DeconstructTupleTest, DeconstructTupleTwice) {
+ ComputationBuilder builder(client_, TestName());
+ auto const1 = builder.ConstantR1<float>({1.0, 2.0, 3.0, 4.0});
+ auto const2 = builder.ConstantR1<float>({2.0, 4.0, 6.0, 8.0});
+ builder.Tuple({const1, const2});
+ auto global_data = ExecuteAndCheckTransfer(&builder, {});
+
+ auto result_status1 = client_->DeconstructTuple(*global_data);
+ EXPECT_TRUE(result_status1.ok());
+ auto result_status2 = client_->DeconstructTuple(*global_data);
+ EXPECT_TRUE(result_status2.ok());
+
+ auto handles1 = result_status1.ConsumeValueOrDie();
+ auto handles2 = result_status2.ConsumeValueOrDie();
+ std::vector<float> copy(4);
+
+ ASSERT_IS_OK(client_->TransferInProcess(*handles1[0], &copy[0]));
+ EXPECT_MATCH(copy, testing::VectorMatcher<float>({1.0, 2.0, 3.0, 4.0}));
+ ASSERT_IS_OK(client_->TransferInProcess(*handles1[1], &copy[0]));
+ EXPECT_MATCH(copy, testing::VectorMatcher<float>({2.0, 4.0, 6.0, 8.0}));
+ handles1[0].reset();
+ handles1[1].reset();
+
+ ASSERT_IS_OK(client_->TransferInProcess(*handles2[0], &copy[0]));
+ EXPECT_MATCH(copy, testing::VectorMatcher<float>({1.0, 2.0, 3.0, 4.0}));
+ ASSERT_IS_OK(client_->TransferInProcess(*handles2[1], &copy[0]));
+ EXPECT_MATCH(copy, testing::VectorMatcher<float>({2.0, 4.0, 6.0, 8.0}));
+}
+
+XLA_TEST_F(DeconstructTupleTest, DeconstructTupleRepeatedElement) {
+ ComputationBuilder builder(client_, TestName());
+ auto const1 = builder.ConstantR1<float>({1.0, 2.0, 3.0, 4.0});
+ auto const2 = builder.ConstantR1<float>({2.0, 4.0, 6.0, 8.0});
+ builder.Tuple({const1, const2, const2, const1});
+ auto global_data = ExecuteAndCheckTransfer(&builder, {});
+
+ auto result_status = client_->DeconstructTuple(*global_data);
+ EXPECT_TRUE(result_status.ok());
+
+ // Verify that the returned handles repeat elements the same way the tuple
+ // does: handles[0] should hold the same data as handles[3], and handles[1]
+ // the same as handles[2].
+ auto handles = result_status.ConsumeValueOrDie();
+
+ std::vector<float> copy(4);
+ ASSERT_IS_OK(client_->TransferInProcess(*handles[0], &copy[0]));
+ EXPECT_MATCH(copy, testing::VectorMatcher<float>({1.0, 2.0, 3.0, 4.0}));
+ ASSERT_IS_OK(client_->TransferInProcess(*handles[1], &copy[0]));
+ EXPECT_MATCH(copy, testing::VectorMatcher<float>({2.0, 4.0, 6.0, 8.0}));
+ ASSERT_IS_OK(client_->TransferInProcess(*handles[2], &copy[0]));
+ EXPECT_MATCH(copy, testing::VectorMatcher<float>({2.0, 4.0, 6.0, 8.0}));
+ ASSERT_IS_OK(client_->TransferInProcess(*handles[3], &copy[0]));
+ EXPECT_MATCH(copy, testing::VectorMatcher<float>({1.0, 2.0, 3.0, 4.0}));
+}
+
+TEST_F(DeconstructTupleTest, DeconstructTupleThenDeallocate) {
+ ComputationBuilder builder(client_, TestName());
+ auto const1 = builder.ConstantR1<float>({1.0, 2.0, 3.0, 4.0});
+ auto const2 = builder.ConstantR1<float>({2.0, 4.0, 6.0, 8.0});
+ builder.Tuple({const1, const2, const1});
+ auto global_data = ExecuteAndCheckTransfer(&builder, {});
+
+ auto result_status = client_->DeconstructTuple(*global_data);
+ EXPECT_TRUE(result_status.ok());
+ auto handles = result_status.ConsumeValueOrDie();
+
+ // Deallocate the tuple, then try copying the elements back. The elements
+ // should not have been deallocated because of reference counting.
+ global_data.reset();
+
+ std::vector<float> copy(4);
+ ASSERT_IS_OK(client_->TransferInProcess(*handles[0], &copy[0]));
+ EXPECT_MATCH(copy, testing::VectorMatcher<float>({1.0, 2.0, 3.0, 4.0}));
+ ASSERT_IS_OK(client_->TransferInProcess(*handles[1], &copy[0]));
+ EXPECT_MATCH(copy, testing::VectorMatcher<float>({2.0, 4.0, 6.0, 8.0}));
+ ASSERT_IS_OK(client_->TransferInProcess(*handles[2], &copy[0]));
+ EXPECT_MATCH(copy, testing::VectorMatcher<float>({1.0, 2.0, 3.0, 4.0}));
+
+ // Try deallocating one of the repeated elements, then copy via the other.
+ handles[0].reset();
+
+ ASSERT_IS_OK(client_->TransferInProcess(*handles[2], &copy[0]));
+ EXPECT_MATCH(copy, testing::VectorMatcher<float>({1.0, 2.0, 3.0, 4.0}));
+}
+
+TEST_F(DeconstructTupleTest, DeconstructNonTuple) {
+ ComputationBuilder builder(client_, TestName());
+ builder.ConstantR1<float>({1.0, 2.0, 3.0, 4.0});
+ auto global_data = ExecuteAndCheckTransfer(&builder, {});
+
+ auto result_status = client_->DeconstructTuple(*global_data);
+ EXPECT_FALSE(result_status.ok());
+ EXPECT_MATCH(result_status.status().error_message(),
+ testing::ContainsRegex("global data handle .* is not a tuple"));
+}
+
+XLA_TEST_F(DeconstructTupleTest, DeconstructTupleFromParam) {
+ ComputationBuilder builder(client_, TestName());
+ std::unique_ptr<Literal> param0_literal =
+ LiteralUtil::CreateR1<float>({3.14f, -100.25f});
+ std::unique_ptr<GlobalData> param0_data =
+ client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ auto p = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2}), "param0");
+ builder.Tuple({p});
+ auto global_data = ExecuteAndCheckTransfer(&builder, {param0_data.get()});
+
+ auto result_status = client_->DeconstructTuple(*global_data);
+ EXPECT_TRUE(result_status.ok());
+ auto handles = result_status.ConsumeValueOrDie();
+ EXPECT_NE(handles[0]->handle().handle(), param0_data->handle().handle());
+}
+
+XLA_TEST_F(DeconstructTupleTest, DeconstructNestedTuple) {
+ ComputationBuilder builder(client_, TestName());
+ auto const1 = builder.ConstantR1<float>({1.0, 2.0, 3.0, 4.0});
+ auto const2 = builder.ConstantR1<float>({2.0, 4.0, 6.0, 8.0});
+ builder.Tuple({builder.Tuple({const1, const2}), const1});
+ auto global_data = ExecuteAndCheckTransfer(&builder, {});
+
+ auto result_status = client_->DeconstructTuple(*global_data);
+ EXPECT_FALSE(result_status.ok());
+ EXPECT_MATCH(
+ result_status.status().error_message(),
+ testing::ContainsRegex("deconstructing nested tuples not yet supported"));
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc
new file mode 100644
index 0000000000..da2d43ca4f
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc
@@ -0,0 +1,387 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+#include <vector>
+
+#include "tensorflow/compiler/xla/array2d.h"
+#include "tensorflow/compiler/xla/array3d.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.h"
+#include "tensorflow/compiler/xla/legacy_flags/layout_util_flags.h"
+#include "tensorflow/compiler/xla/primitive_util.h"
+#include "tensorflow/compiler/xla/reference_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/tests/test_utils.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+namespace xla {
+namespace {
+
+// TODO(mfdyck): use googletest typed tests when we can do all tests on all
+// backends.
+class DotOperationTest : public ClientLibraryTestBase {
+ public:
+ ErrorSpec error_spec_{0.0001, 1e-5};
+
+ protected:
+ template <typename Element>
+ void TestOneElementVectorDot();
+ template <typename Element>
+ void TestVectorDot();
+ template <typename Element>
+ void TestSquareMatrixDot(bool lhs_row_major = false,
+ bool rhs_row_major = false);
+ template <typename Element>
+ void TestNonsquareMatrixDot(bool lhs_row_major = false,
+ bool rhs_row_major = false);
+};
+
+XLA_TEST_F(DotOperationTest, ZeroElementVectorDotF32) {
+ ComputationBuilder builder(client_, TestName());
+ auto lhs = builder.ConstantR1<float>({});
+ auto rhs = builder.ConstantR1<float>({});
+ auto result = builder.Dot(lhs, rhs);
+
+ ComputeAndCompareR0<float>(&builder, 0.0, {}, error_spec_);
+}
+
+template <typename Element>
+void DotOperationTest::TestOneElementVectorDot() {
+ ComputationBuilder builder(client_, TestName());
+ auto lhs = builder.ConstantR1<Element>({2.0});
+ auto rhs = builder.ConstantR1<Element>({3.0});
+ auto result = builder.Dot(lhs, rhs);
+
+ ComputeAndCompareR0<Element>(&builder, 6.0, {}, error_spec_);
+}
+
+XLA_TEST_F(DotOperationTest, OneElementVectorDotF32) {
+ TestOneElementVectorDot<float>();
+}
+
+XLA_TEST_F(DotOperationTest, OneElementVectorDotF64) {
+ TestOneElementVectorDot<double>();
+}
+
+template <typename Element>
+void DotOperationTest::TestVectorDot() {
+ ComputationBuilder builder(client_, TestName());
+ auto lhs = builder.ConstantR1<Element>({1.0, 2.5, 42.0});
+ auto rhs = builder.ConstantR1<Element>({11.0, -1.0, 0.5});
+ auto result = builder.Dot(lhs, rhs);
+
+ ComputeAndCompareR0<Element>(&builder, 29.5, {}, error_spec_);
+}
+
+XLA_TEST_F(DotOperationTest, VectorDotF32) { TestVectorDot<float>(); }
+
+XLA_TEST_F(DotOperationTest, VectorDotF64) { TestVectorDot<double>(); }
+
+namespace {
+
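+// Returns the minor-to-major dimension order for a rank-2 layout: {1, 0}
+// makes dimension 1 the fastest-varying, so rows are contiguous (row-major),
+// while {0, 1} is column-major.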
+std::vector<int64> MinorToMajorForIsRowMajor(bool row_major) {
+ return {row_major ? 1 : 0, row_major ? 0 : 1};
+}
+
+} // namespace
+
+XLA_TEST_F(DotOperationTest, Dot_0x2_2x0) {
+ ComputationBuilder builder(client_, TestName());
+ auto lhs = builder.ConstantR2FromArray2D<float>(Array2D<float>(0, 2));
+ auto rhs = builder.ConstantR2FromArray2D<float>(Array2D<float>(2, 0));
+ auto result = builder.Dot(lhs, rhs);
+
+ ComputeAndCompareR2<float>(&builder, Array2D<float>(0, 0), {}, error_spec_);
+}
+
+XLA_TEST_F(DotOperationTest, Dot_0x2_2x3) {
+ ComputationBuilder builder(client_, TestName());
+ auto lhs = builder.ConstantR2FromArray2D<float>(Array2D<float>(0, 2));
+ auto rhs = builder.ConstantR2<float>({{7.0, 8.0, 9.0}, {42.0, 77.0, 101.0}});
+ auto result = builder.Dot(lhs, rhs);
+
+ ComputeAndCompareR2<float>(&builder, Array2D<float>(0, 3), {}, error_spec_);
+}
+
+XLA_TEST_F(DotOperationTest, Dot_3x2_2x0) {
+ ComputationBuilder builder(client_, TestName());
+ auto lhs =
+ builder.ConstantR2<float>({{7.0, 8.0}, {9.0, 42.0}, {77.0, 101.0}});
+ auto rhs = builder.ConstantR2FromArray2D<float>(Array2D<float>(2, 0));
+ auto result = builder.Dot(lhs, rhs);
+
+ ComputeAndCompareR2<float>(&builder, Array2D<float>(3, 0), {}, error_spec_);
+}
+
+XLA_TEST_F(DotOperationTest, Dot_2x0_0x2) {
+ ComputationBuilder builder(client_, TestName());
+ auto lhs = builder.ConstantR2FromArray2D<float>(Array2D<float>(2, 0));
+ auto rhs = builder.ConstantR2FromArray2D<float>(Array2D<float>(0, 2));
+ auto result = builder.Dot(lhs, rhs);
+
+ ComputeAndCompareR2<float>(&builder, Array2D<float>(2, 2, 0.0f), {},
+ error_spec_);
+}
+
+template <typename Element>
+void DotOperationTest::TestSquareMatrixDot(bool lhs_row_major,
+ bool rhs_row_major) {
+ auto lhs_handle =
+ client_
+ ->TransferToServer(*test_utils::CreateR2LiteralWithLayout<Element>(
+ {{1.0, 2.0}, {3.0, -4.0}},
+ MinorToMajorForIsRowMajor(lhs_row_major)))
+ .ConsumeValueOrDie();
+ auto rhs_handle =
+ client_
+ ->TransferToServer(*test_utils::CreateR2LiteralWithLayout<Element>(
+ {{1.0, 6.0}, {7.0, -4.0}},
+ MinorToMajorForIsRowMajor(rhs_row_major)))
+ .ConsumeValueOrDie();
+
+ ComputationBuilder builder(client_, TestName());
+ auto prim_type = primitive_util::NativeToPrimitiveType<Element>();
+ auto result = builder.Dot(
+ builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {2, 2}), "lhs"),
+ builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {2, 2}), "rhs"));
+
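+ // Worked check: {{1, 2}, {3, -4}} * {{1, 6}, {7, -4}} gives
+ // {{1*1 + 2*7, 1*6 + 2*-4}, {3*1 + -4*7, 3*6 + -4*-4}} =
+ // {{15, -2}, {-25, 34}}.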
+ Array2D<Element> expected({{15.0, -2.0}, {-25.0, 34.0}});
+ ComputeAndCompareR2<Element>(
+ &builder, expected, {lhs_handle.get(), rhs_handle.get()}, error_spec_);
+}
+
+XLA_TEST_F(DotOperationTest, SquareMatrixDotF32MinorToMajorFF) {
+ constexpr bool kLhsRowMajor = false;
+ constexpr bool kRhsRowMajor = false;
+ TestSquareMatrixDot<float>(kLhsRowMajor, kRhsRowMajor);
+}
+
+XLA_TEST_F(DotOperationTest, SquareMatrixDotF32MinorToMajorFT) {
+ TestSquareMatrixDot<float>(false, true);
+}
+
+XLA_TEST_F(DotOperationTest, SquareMatrixDotF32MinorToMajorTF) {
+ TestSquareMatrixDot<float>(true, false);
+}
+
+TEST_F(DotOperationTest, SquareMatrixDotF32MinorToMajorTT) {
+ constexpr bool kLhsRowMajor = true;
+ constexpr bool kRhsRowMajor = true;
+ TestSquareMatrixDot<float>(kLhsRowMajor, kRhsRowMajor);
+}
+
+XLA_TEST_F(DotOperationTest, SquareMatrixDotF64) {
+ TestSquareMatrixDot<double>();
+}
+
+template <typename Element>
+void DotOperationTest::TestNonsquareMatrixDot(bool lhs_row_major,
+ bool rhs_row_major) {
+ auto lhs_handle =
+ client_
+ ->TransferToServer(*test_utils::CreateR2LiteralWithLayout<Element>(
+ {{1.0, 2.0, 3.0}, {3.0, -4.0, -1.0}},
+ MinorToMajorForIsRowMajor(lhs_row_major)))
+ .ConsumeValueOrDie();
+ auto rhs_handle =
+ client_
+ ->TransferToServer(*test_utils::CreateR2LiteralWithLayout<Element>(
+ {{1.0, 6.0}, {2.0, 3.0}, {7.0, -4.0}},
+ MinorToMajorForIsRowMajor(rhs_row_major)))
+ .ConsumeValueOrDie();
+
+ ComputationBuilder builder(client_, TestName());
+ auto prim_type = primitive_util::NativeToPrimitiveType<Element>();
+ auto result = builder.Dot(
+ builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {2, 3}), "lhs"),
+ builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {3, 2}), "rhs"));
+
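+ // Worked check: row 0 is {1*1 + 2*2 + 3*7, 1*6 + 2*3 + 3*-4} = {26, 0};
+ // row 1 is {3*1 + -4*2 + -1*7, 3*6 + -4*3 + -1*-4} = {-12, 10}.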
+ Array2D<Element> expected({{26.0, 0.0}, {-12.0, 10.0}});
+
+ ComputeAndCompareR2<Element>(
+ &builder, expected, {lhs_handle.get(), rhs_handle.get()}, error_spec_);
+}
+
+XLA_TEST_F(DotOperationTest, NonsquareMatrixDotF32MajorToMinorFF) {
+ constexpr bool kLhsRowMajor = false;
+ constexpr bool kRhsRowMajor = false;
+ TestNonsquareMatrixDot<float>(kLhsRowMajor, kRhsRowMajor);
+}
+
+XLA_TEST_F(DotOperationTest, NonsquareMatrixDotF32MajorToMinorFT) {
+ constexpr bool kLhsRowMajor = false;
+ constexpr bool kRhsRowMajor = true;
+ TestNonsquareMatrixDot<float>(kLhsRowMajor, kRhsRowMajor);
+}
+
+XLA_TEST_F(DotOperationTest, NonsquareMatrixDotF32MajorToMinorTF) {
+ constexpr bool kLhsRowMajor = true;
+ constexpr bool kRhsRowMajor = false;
+ TestNonsquareMatrixDot<float>(kLhsRowMajor, kRhsRowMajor);
+}
+
+TEST_F(DotOperationTest, NonsquareMatrixDotF32MajorToMinorTT) {
+ constexpr bool kLhsRowMajor = true;
+ constexpr bool kRhsRowMajor = true;
+ TestNonsquareMatrixDot<float>(kLhsRowMajor, kRhsRowMajor);
+}
+
+XLA_TEST_F(DotOperationTest, NonsquareMatrixDotF64) {
+ TestNonsquareMatrixDot<double>();
+}
+
+TEST_F(DotOperationTest, ConcurrentMatMul) {
+ ComputationBuilder builder(client_, TestName());
+ auto matrix1 = builder.ConstantR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+ auto matrix2 = builder.ConstantR2<float>({{5.0, 6.0}, {7.0, 8.0}});
+ auto matrix12 = builder.Dot(matrix1, matrix2);
+ auto matrix21 = builder.Dot(matrix2, matrix1);
+ builder.Add(matrix12, matrix21);
+
+ Array2D<float> expected({{42.0, 56.0}, {74.0, 96.0}});
+ ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
+}
+
+// Regression test for b/32055648. The root of the graph is a kFusion of 4
+// bitcasts. Although bitcasts don't map to thunks, the root should still be
+// sync-dependent on the bitcasts' operands.
+XLA_TEST_F(DotOperationTest, BatchMatMul) {
+ ComputationBuilder builder(client_, TestName());
+ auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2, 2, 2}), "x");
+ auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {2, 2, 2, 2}), "y");
+
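+ // Collapse the two leading [2, 2] batch dimensions into one dimension of 4.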
+ auto x_flat = builder.Reshape(x, {0, 1, 2, 3}, {4, 2, 2});
+ auto y_flat = builder.Reshape(y, {0, 1, 2, 3}, {4, 2, 2});
+
+ // Slice batches into individual matrices and multiply them.
+ std::vector<xla::ComputationDataHandle> out_slices;
+ for (int i = 0; i < 4; ++i) {
+ // Slice off individual matrices and reshape to 2D tensors.
+ auto x_slice = builder.Slice(x_flat, {i, 0, 0}, {i + 1, 2, 2});
+ x_slice = builder.Reshape(x_slice, {0, 1, 2}, {2, 2});
+ auto y_slice = builder.Slice(y_flat, {i, 0, 0}, {i + 1, 2, 2});
+ y_slice = builder.Reshape(y_slice, {0, 1, 2}, {2, 2});
+
+ auto out = builder.Dot(x_slice, y_slice);
+ out = builder.Reshape(out, {0, 1}, {1, 2, 2});
+ out_slices.push_back(out);
+ }
+ auto out_flat = builder.ConcatInDim(out_slices, 0);
+ builder.Reshape(out_flat, {0, 1, 2}, {2, 2, 2, 2});
+
+ auto x_data = client_
+ ->TransferToServer(*LiteralUtil::CreateR4<float>(
+ {{{{1000, 100}, {10, 1}}, {{2000, 200}, {20, 2}}},
+ {{{3000, 300}, {30, 3}}, {{4000, 400}, {40, 4}}}}))
+ .ConsumeValueOrDie();
+ auto y_data = client_
+ ->TransferToServer(*LiteralUtil::CreateR4<float>(
+ {{{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}},
+ {{{11, 22}, {33, 44}}, {{55, 66}, {77, 88}}}}))
+ .ConsumeValueOrDie();
+
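+ // Spot check of the first batch: [[1000, 100], [10, 1]] * [[1, 2], [3, 4]]
+ // = [[1300, 2400], [13, 24]], which is the first 2x2 slice expected below.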
+ ComputeAndCompareR4<float>(
+ &builder,
+ /*expected=*/{{{{1300, 2400}, {13, 24}}, {{11400, 13600}, {114, 136}}},
+ {{{42900, 79200}, {429, 792}},
+ {{250800, 299200}, {2508, 2992}}}},
+ {x_data.get(), y_data.get()}, error_spec_);
+}
+
+TEST_F(DotOperationTest, TransposeFolding) {
+ for (bool transpose_lhs : {false, true}) {
+ for (bool transpose_rhs : {false, true}) {
+ for (bool row_major : {false, true}) {
+ std::unique_ptr<Array2D<float>> lhs(
+ new Array2D<float>({{1.0, 2.0, 3.0}, {3.0, -4.0, -1.0}}));
+ std::unique_ptr<Array2D<float>> rhs(
+ new Array2D<float>({{1.0, 6.0}, {2.0, 3.0}, {7.0, -4.0}}));
+
+ if (transpose_lhs) {
+ lhs = ReferenceUtil::TransposeArray2D(*lhs);
+ }
+ if (transpose_rhs) {
+ rhs = ReferenceUtil::TransposeArray2D(*rhs);
+ }
+ auto lhs_handle =
+ client_
+ ->TransferToServer(
+ *LiteralUtil::CreateR2FromArray2DWithLayout<float>(
+ *lhs, LayoutUtil::MakeLayout(
+ MinorToMajorForIsRowMajor(row_major))))
+ .ConsumeValueOrDie();
+ auto rhs_handle =
+ client_
+ ->TransferToServer(
+ *LiteralUtil::CreateR2FromArray2DWithLayout<float>(
+ *rhs, LayoutUtil::MakeLayout(
+ MinorToMajorForIsRowMajor(row_major))))
+ .ConsumeValueOrDie();
+
+ ComputationBuilder builder(client_, TestName());
+ auto prim_type = primitive_util::NativeToPrimitiveType<float>();
+ auto lhs_arg = builder.Parameter(
+ 0, ShapeUtil::MakeShape(prim_type, {lhs->height(), lhs->width()}),
+ "lhs");
+ auto rhs_arg = builder.Parameter(
+ 1, ShapeUtil::MakeShape(prim_type, {rhs->height(), rhs->width()}),
+ "rhs");
+ if (transpose_lhs) {
+ lhs_arg = builder.Transpose(lhs_arg, {1, 0});
+ }
+ if (transpose_rhs) {
+ rhs_arg = builder.Transpose(rhs_arg, {1, 0});
+ }
+ auto result = builder.Dot(lhs_arg, rhs_arg);
+
+ Array2D<float> expected({{26.0, 0.0}, {-12.0, 10.0}});
+ VLOG(1) << "TestTransposeFolding " << transpose_lhs << " "
+ << transpose_rhs << " " << row_major;
+ ComputeAndCompareR2<float>(&builder, expected,
+ {lhs_handle.get(), rhs_handle.get()},
+ error_spec_);
+ }
+ }
+ }
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendLayoutUtilFlags(&flag_list);
+ xla::legacy_flags::AppendCpuRuntimeFlags(&flag_list);
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc
new file mode 100644
index 0000000000..cecc4872df
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc
@@ -0,0 +1,506 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <numeric>
+#include <vector>
+
+#include "tensorflow/compiler/xla/array2d.h"
+#include "tensorflow/compiler/xla/client/client_library.h"
+#include "tensorflow/compiler/xla/client/computation.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/reference_util.h"
+#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
+#include "tensorflow/compiler/xla/service/local_service.h"
+#include "tensorflow/compiler/xla/service/platform_util.h"
+#include "tensorflow/compiler/xla/service/shaped_buffer.h"
+#include "tensorflow/compiler/xla/service/transfer_manager.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace se = ::perftools::gputools;
+
+namespace xla {
+namespace {
+
+class DynamicSliceTest : public ClientLibraryTestBase {
+ protected:
+ template <typename IndexT>
+ void TestR1() {
+ // Slice at dimension start.
+ RunR1<IndexT>({0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0}, {0}, {5},
+ {0.0, 1.0, 2.0, 3.0, 4.0});
+ // Slice in the middle.
+ RunR1<IndexT>({0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0}, {2}, {3},
+ {2.0, 3.0, 4.0});
+ // Slice at dimension boundaries.
+ RunR1<IndexT>({0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0}, {5}, {3},
+ {5.0, 6.0, 7.0});
+ // Slice at dimension boundaries, but with sizes that cause indices to wrap.
+ RunR1<IndexT>({0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0}, {6}, {4},
+ {6.0, 7.0, 0.0, 1.0});
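+ // That is, a start of 6 with size 4 over a length-8 input reads indices
+ // {6, 7, 0, 1}, wrapping around modulo the dimension size.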
+ }
+
+ template <typename IndexT>
+ void TestR2() {
+ // Slice at dimension start.
+ RunR2<IndexT>({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}},
+ {0, 0}, {2, 2}, {{1.0f, 2.0f}, {4.0f, 5.0f}});
+ // Slice in the middle.
+ RunR2<IndexT>({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}},
+ {1, 1}, {2, 1}, {{5.0f}, {8.0f}});
+ // Slice at dimension boundaries.
+ RunR2<IndexT>({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}},
+ {1, 1}, {2, 1}, {{5.0f}, {8.0f}});
+ // Slice at dimension boundaries, but with sizes that cause indices to wrap.
+ RunR2<IndexT>({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}},
+ {1, 1}, {3, 3},
+ {{5.0f, 6.0f, 4.0f}, {8.0f, 9.0f, 7.0f}, {2.0f, 3.0f, 1.0f}});
+ }
+
+ template <typename IndexT>
+ void TestR3() {
+ // R3 Shape: [2, 3, 2]
+ // clang-format off
+
+ // Slice at dimension start.
+ RunR3<IndexT>(
+ {{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}},
+ {{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 12.0f}}},
+ {0, 0, 0}, {2, 1, 2},
+ {{{1.0f, 2.0f}}, {{7.0f, 8.0f}}});
+
+ // Slice in the middle.
+ RunR3<IndexT>(
+ {{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}},
+ {{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 12.0f}}},
+ {0, 1, 1}, {2, 2, 1},
+ {{{4.0f}, {6.0f}}, {{10.0f}, {12.0f}}});
+
+ // Slice at dimension boundaries, but with sizes that cause indices to wrap.
+ RunR3<IndexT>(
+ {{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}},
+ {{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 12.0f}}},
+ {0, 2, 1}, {2, 2, 1},
+ {{{6.0f}, {2.0f}}, {{12.0f}, {8.0f}}});
+
+ // clang-format on
+ }
+
+ template <typename IndexT>
+ void RunR1(const std::vector<float>& input_values,
+ const std::vector<IndexT> slice_starts,
+ const std::vector<int64> slice_sizes,
+ const std::vector<float>& expected_values) {
+ ComputationBuilder builder(client_, TestName());
+ // Initialize and transfer dynamic slice start indices parameter.
+ ComputationDataHandle starts;
+ std::unique_ptr<GlobalData> start_data = CreateR1Parameter<IndexT>(
+ slice_starts, 0, "slice_starts", &builder, &starts);
+ // Build dynamic slice computation.
+ auto input = builder.ConstantR1<float>(input_values);
+ builder.DynamicSlice(input, starts, slice_sizes);
+ // Run computation and compare against expected values.
+ ComputeAndCompareR1<float>(&builder, expected_values, {start_data.get()},
+ ErrorSpec(0.000001));
+ }
+
+ template <typename IndexT>
+ void RunR2(const Array2D<float>& input_values,
+ const std::vector<IndexT> slice_starts,
+ const std::vector<int64> slice_sizes,
+ const Array2D<float>& expected_values) {
+ ComputationBuilder builder(client_, TestName());
+ // Initialize and transfer dynamic slice start indices parameter.
+ ComputationDataHandle starts;
+ std::unique_ptr<GlobalData> start_data = CreateR1Parameter<IndexT>(
+ slice_starts, 0, "slice_starts", &builder, &starts);
+ // Build dynamic slice computation.
+ auto input = builder.ConstantR2FromArray2D<float>(input_values);
+ builder.DynamicSlice(input, starts, slice_sizes);
+ // Run computation and compare against expected values.
+ ComputeAndCompareR2<float>(&builder, expected_values, {start_data.get()},
+ ErrorSpec(0.000001));
+ }
+
+ template <typename IndexT>
+ void RunR3(const Array3D<float>& input_values,
+ const std::vector<IndexT> slice_starts,
+ const std::vector<int64> slice_sizes,
+ const Array3D<float>& expected_values) {
+ ComputationBuilder builder(client_, TestName());
+ // Initialize and transfer dynamic slice start indices parameter.
+ ComputationDataHandle starts;
+ std::unique_ptr<GlobalData> start_data = CreateR1Parameter<IndexT>(
+ slice_starts, 0, "slice_starts", &builder, &starts);
+ // Build dynamic slice computation.
+ auto input = builder.ConstantR3FromArray3D<float>(input_values);
+ builder.DynamicSlice(input, starts, slice_sizes);
+ // Run computation and compare against expected values.
+ ComputeAndCompareR3<float>(&builder, expected_values, {start_data.get()},
+ ErrorSpec(0.000001));
+ }
+};
+
+XLA_TEST_F(DynamicSliceTest, Int32R1) { TestR1<int32>(); }
+
+XLA_TEST_F(DynamicSliceTest, Int64R1) { TestR1<int64>(); }
+
+XLA_TEST_F(DynamicSliceTest, UInt64R1) { TestR1<uint64>(); }
+
+XLA_TEST_F(DynamicSliceTest, Int32R2) { TestR2<int32>(); }
+
+XLA_TEST_F(DynamicSliceTest, Int64R2) { TestR2<int64>(); }
+
+XLA_TEST_F(DynamicSliceTest, UInt64R2) { TestR2<uint64>(); }
+
+XLA_TEST_F(DynamicSliceTest, Int32R3) { TestR3<int32>(); }
+
+XLA_TEST_F(DynamicSliceTest, Int64R3) { TestR3<int64>(); }
+
+XLA_TEST_F(DynamicSliceTest, UInt64R3) { TestR3<uint64>(); }
+
+class DynamicUpdateSliceTest : public ClientLibraryTestBase {
+ protected:
+ template <typename IndexT>
+ void TestR1() {
+ // clang-format off
+ // Slice at dimension start.
+ RunR1<IndexT>({0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0},
+ {8.0, 9.0, 10.0}, {0},
+ {8.0, 9.0, 10.0, 3.0, 4.0, 5.0, 6.0, 7.0});
+ // Slice in the middle.
+ RunR1<IndexT>({0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0},
+ {8.0, 9.0, 10.0}, {2},
+ {0.0, 1.0, 8.0, 9.0, 10.0, 5.0, 6.0, 7.0});
+ // Slice at dimension boundaries.
+ RunR1<IndexT>({0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0},
+ {8.0, 9.0, 10.0}, {5},
+ {0.0, 1.0, 2.0, 3.0, 4.0, 8.0, 9.0, 10.0});
+ // Slice at dimension boundaries, but with sizes that cause indices to wrap.
+ RunR1<IndexT>({0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0},
+ {8.0, 9.0, 10.0}, {6},
+ {0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 8.0, 9.0});
+ // clang-format on
+ }
+
+ template <typename IndexT>
+ void TestR2() {
+ // clang-format off
+ // Slice at dimension start.
+ RunR2<IndexT>(
+ {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}},
+ {{10.0f, 11.0f}}, {0, 0},
+ {{10.0f, 11.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}});
+ // Slice in the middle.
+ RunR2<IndexT>(
+ {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}},
+ {{10.0f, 11.0f}}, {1, 1},
+ {{1.0f, 2.0f, 3.0f}, {4.0f, 10.0f, 11.0f}, {7.0f, 8.0f, 9.0f}});
+ // Slice at dimension boundaries.
+ RunR2<IndexT>(
+ {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}},
+ {{10.0f, 11.0f}}, {2, 1},
+ {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 10.0f, 11.0f}});
+ // Slice at dimension boundaries, but with sizes that cause indices to wrap.
+ RunR2<IndexT>(
+ {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}},
+ {{10.0f, 11.0f}}, {2, 2},
+ {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 8.0f, 10.0f}});
+ // clang-format on
+ }
+
+ template <typename IndexT>
+ void TestR3() {
+ // R3 Shape: [2, 3, 2]
+ // clang-format off
+ // Slice at dimension start.
+ RunR3<IndexT>(
+ {{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}},
+ {{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 12.0f}}},
+ {{{13.0f, 14.0f}, {15.0f, 16.0f}},
+ {{17.0f, 18.0f}, {19.0f, 20.0f}}},
+ {0, 0, 0},
+ {{{13.0f, 14.0f}, {15.0f, 16.0f}, {5.0f, 6.0f}},
+ {{17.0f, 18.0f}, {19.0f, 20.0f}, {11.0f, 12.0f}}});
+ // Slice in the middle.
+ RunR3<IndexT>(
+ {{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}},
+ {{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 12.0f}}},
+ {{{13.0f}, {15.0f}}},
+ {1, 1, 1},
+ {{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}},
+ {{7.0f, 8.0f}, {9.0f, 13.0f}, {11.0f, 15.0f}}});
+ // Slice at dimension boundaries, but with sizes that cause indices to wrap.
+ RunR3<IndexT>(
+ {{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}},
+ {{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 12.0f}}},
+ {{{13.0f}, {15.0f}}},
+ {1, 2, 1},
+ {{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}},
+ {{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 13.0f}}});
+ // clang-format on
+ }
+
+ template <typename IndexT>
+ void RunR1(const std::vector<float>& input_values,
+ const std::vector<float>& update_values,
+ const std::vector<IndexT> slice_starts,
+ const std::vector<float>& expected_values) {
+ ComputationBuilder builder(client_, TestName());
+ // Initialize and transfer dynamic slice start indices parameter.
+ ComputationDataHandle starts;
+ std::unique_ptr<GlobalData> start_data = CreateR1Parameter<IndexT>(
+ slice_starts, 0, "slice_starts", &builder, &starts);
+ // Build dynamic slice computation.
+ auto input = builder.ConstantR1<float>(input_values);
+ auto update = builder.ConstantR1<float>(update_values);
+ builder.DynamicUpdateSlice(input, update, starts);
+ // Run computation and compare against expected values.
+ ComputeAndCompareR1<float>(&builder, expected_values, {start_data.get()},
+ ErrorSpec(0.000001));
+ }
+
+ template <typename IndexT>
+ void RunR2(const Array2D<float>& input_values,
+ const Array2D<float>& update_values,
+ const std::vector<IndexT> slice_starts,
+ const Array2D<float>& expected_values) {
+ ComputationBuilder builder(client_, TestName());
+ // Initialize and transfer dynamic slice start indices parameter.
+ ComputationDataHandle starts;
+ std::unique_ptr<GlobalData> start_data = CreateR1Parameter<IndexT>(
+ slice_starts, 0, "slice_starts", &builder, &starts);
+ // Build dynamic slice computation.
+ auto input = builder.ConstantR2FromArray2D<float>(input_values);
+ auto update = builder.ConstantR2FromArray2D<float>(update_values);
+ builder.DynamicUpdateSlice(input, update, starts);
+ // Run computation and compare against expected values.
+ ComputeAndCompareR2<float>(&builder, expected_values, {start_data.get()},
+ ErrorSpec(0.000001));
+ }
+
+ template <typename IndexT>
+ void RunR3(const Array3D<float>& input_values,
+ const Array3D<float>& update_values,
+ const std::vector<IndexT> slice_starts,
+ const Array3D<float>& expected_values) {
+ ComputationBuilder builder(client_, TestName());
+ // Initialize and transfer dynamic slice start indices parameter.
+ ComputationDataHandle starts;
+ std::unique_ptr<GlobalData> start_data = CreateR1Parameter<IndexT>(
+ slice_starts, 0, "slice_starts", &builder, &starts);
+ // Build dynamic slice computation.
+ auto input = builder.ConstantR3FromArray3D<float>(input_values);
+ auto update = builder.ConstantR3FromArray3D<float>(update_values);
+ builder.DynamicUpdateSlice(input, update, starts);
+ // Run computation and compare against expected values.
+ ComputeAndCompareR3<float>(&builder, expected_values, {start_data.get()},
+ ErrorSpec(0.000001));
+ }
+
+ void RunR3Contiguous(std::vector<int32> operand_shape, int32 index,
+ int32 size) {
+ const int32 kSeq = operand_shape[0];
+ const int32 kBatch = operand_shape[1];
+ const int32 kDim = operand_shape[2];
+ Array3D<float> input_values(kSeq, kBatch, kDim);
+ Array3D<float> update_values(size, kBatch, kDim);
+ Array3D<float> expected_values(kSeq, kBatch, kDim);
+
+ input_values.FillIota(0);
+ float val = 1000;
+ update_values.FillIota(val);
+
+ // TODO(b/34128753) Expected values may vary depending on the backend when
+ // the update wraps. Per the documentation, results are implementation
+ // specific when the update is out of bounds, so we don't really know what
+ // to pass into ComputeAndCompareR3.
+ expected_values.FillIota(0);
+ for (int i = 0; i < size; i++) {
+ for (int j = 0; j < kBatch; j++) {
+ for (int k = 0; k < kDim; k++) {
+ expected_values((index + i) % kSeq, j, k) = val++;
+ }
+ }
+ }
+ if (VLOG_IS_ON(1)) {
+ DumpArray<float>("input", input_values);
+ DumpArray<float>("update", update_values);
+ DumpArray<float>("expected", expected_values);
+ }
+
+ // Build dynamic slice computation.
+ ComputationBuilder builder(client_, TestName());
+ auto starts = builder.ConstantR1<int32>({index, 0, 0});
+ auto input = builder.ConstantR3FromArray3D<float>(input_values);
+ auto update = builder.ConstantR3FromArray3D<float>(update_values);
+ builder.DynamicUpdateSlice(input, update, starts);
+
+ // Run computation and compare against expected values.
+ ComputeAndCompareR3<float>(&builder, expected_values, {},
+ ErrorSpec(0.000001));
+ }
+
+ template <typename NativeT>
+ void DumpArray(const string& name, const Array3D<NativeT>& values) {
+ std::unique_ptr<Literal> literal =
+ LiteralUtil::CreateR3FromArray3D<NativeT>(values);
+ LOG(INFO) << name << ":" << LiteralUtil::ToString(*literal);
+ }
+};
+
+XLA_TEST_F(DynamicUpdateSliceTest, Int32R1) { TestR1<int32>(); }
+
+XLA_TEST_F(DynamicUpdateSliceTest, Int64R1) { TestR1<int64>(); }
+
+XLA_TEST_F(DynamicUpdateSliceTest, UInt64R1) { TestR1<uint64>(); }
+
+XLA_TEST_F(DynamicUpdateSliceTest, Int32R2) { TestR2<int32>(); }
+
+XLA_TEST_F(DynamicUpdateSliceTest, Int64R2) { TestR2<int64>(); }
+
+XLA_TEST_F(DynamicUpdateSliceTest, UInt64R2) { TestR2<uint64>(); }
+
+XLA_TEST_F(DynamicUpdateSliceTest, Int32R3) { TestR3<int32>(); }
+
+XLA_TEST_F(DynamicUpdateSliceTest, Int64R3) { TestR3<int64>(); }
+
+XLA_TEST_F(DynamicUpdateSliceTest, UInt64R3) { TestR3<uint64>(); }
+
+// Tests for the simple R3 case where the update is contiguous (i.e. the two
+// minor dimensions are not sliced).
+XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousSingleElement) {
+ // Single element, no wrap.
+ std::vector<int32> operand_shape({4, 5, 2});
+ RunR3Contiguous(operand_shape, /*index=*/1, /*size=*/1);
+}
+
+XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleElements) {
+ // Multiple elements, no wrap.
+ std::vector<int32> operand_shape({4, 5, 2});
+ RunR3Contiguous(operand_shape, /*index=*/1, /*size=*/2);
+}
+
+// TODO(b/34128753) CPU and GPU failed on 2016-01-06. Appears not to handle
+// wrapping as expected.
+XLA_TEST_F(DynamicUpdateSliceTest,
+ DISABLED_ON_CPU(DISABLED_ON_GPU(R3ContiguousMultipleWrapping))) {
+ // Multiple elements, wrapping.
+ std::vector<int32> operand_shape({4, 5, 2});
+ RunR3Contiguous(operand_shape, /*index=*/3, /*size=*/2);
+}
+
+// TODO(b/34128753) CPU and GPU failed on 2016-01-06. Appears not to handle
+// wrapping as expected.
+XLA_TEST_F(DynamicUpdateSliceTest,
+ DISABLED_ON_CPU(DISABLED_ON_GPU(R3ContiguousTooLarge))) {
+ // Multiple elements, update size larger than the operand.
+ std::vector<int32> operand_shape({4, 5, 2});
+ RunR3Contiguous(operand_shape, /*index=*/5, /*size=*/2);
+}
+
+XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousUnaligned) {
+ std::vector<int32> operand_shape({3, 123, 247});
+ RunR3Contiguous(operand_shape, /*index=*/1, /*size=*/1);
+}
+
+// TODO(b/34134076) Disabled on GPU 2016-01-06 due to out-of-memory error.
+XLA_TEST_F(DynamicUpdateSliceTest, DISABLED_ON_GPU(R3ContiguousLarger)) {
+ std::vector<int32> operand_shape({32, 128, 1024});
+ RunR3Contiguous(operand_shape, /*index=*/7, /*size=*/1);
+}
+
+void BM_DynamicSlice(int num_iters) {
+ tensorflow::testing::StopTiming();
+
+ se::Platform* platform = PlatformUtil::GetDefaultPlatform().ValueOrDie();
+ auto executors = PlatformUtil::GetStreamExecutors(platform).ValueOrDie();
+ StreamExecutorMemoryAllocator allocator(platform, executors);
+ LocalClient* client =
+ ClientLibrary::GetOrCreateLocalClient(platform).ValueOrDie();
+ auto* transfer_manager =
+ TransferManager::GetForPlatform(platform).ValueOrDie();
+ int device_ordinal = client->default_device_ordinal();
+
+ ComputationBuilder builder(client, "DynamicSlice");
+
+ // Create input as a constant: shape [1, 2, 3, 4]
+ auto input_literal = LiteralUtil::CreateR4(
+ {{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}},
+ {{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}});
+ auto input = builder.ConstantLiteral(*input_literal);
+
+ // Create dynamic slice start indices as a parameter: shape [4]
+ auto start_indices_shape = ShapeUtil::MakeShape(S32, {4});
+ auto start_indices =
+ builder.Parameter(0, start_indices_shape, "start_indices");
+ // Add the DynamicSlice op to the computation.
+ builder.DynamicSlice(input, start_indices, {1, 1, 1, 1});
+ auto computation = builder.Build().ConsumeValueOrDie();
+
+ // Initialize and transfer parameter buffer.
+ auto buffer = ScopedShapedBuffer::MakeScopedShapedBuffer(start_indices_shape,
+ &allocator, 0)
+ .ConsumeValueOrDie();
+
+ auto start_indices_literal = LiteralUtil::CreateR1<int32>({0, 1, 2, 3});
+ ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice(
+ executors[device_ordinal], *start_indices_literal,
+ buffer->mutable_buffer({})));
+
+ // Run some warm-up executions.
+ LocalExecuteOptions options;
+ options.set_allocator(&allocator);
+ const int kWarmups = 2;
+ for (int i = 0; i < kWarmups; ++i) {
+ auto result = client->ExecuteLocally(computation, {buffer.get()}, options);
+ ASSERT_TRUE(result.ok());
+ }
+
+ // Run benchmark.
+ tensorflow::testing::StartTiming();
+ for (int i = 0; i < num_iters; ++i) {
+ auto result = client->ExecuteLocally(computation, {buffer.get()}, options);
+ ASSERT_TRUE(result.ok());
+ }
+}
+BENCHMARK(BM_DynamicSlice);
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/floor_ceil_test.cc b/tensorflow/compiler/xla/tests/floor_ceil_test.cc
new file mode 100644
index 0000000000..8e30063085
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/floor_ceil_test.cc
@@ -0,0 +1,128 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <limits>
+#include <string>
+
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace {
+
+class FloorCeilTest : public ClientLibraryTestBase {
+ public:
+ enum Function {
+ kFloor,
+ kCeil,
+ };
+
+ // Runs the computation f(input) and compares the result against expected.
+ void TestR1F32(tensorflow::gtl::ArraySlice<float> input,
+ tensorflow::gtl::ArraySlice<float> expected, Function f) {
+ LOG(INFO) << "input: {" << tensorflow::str_util::Join(expected, ", ")
+ << "}";
+ ComputationBuilder builder(client_, TestName());
+ auto c = builder.ConstantR1<float>(input);
+ if (f == kCeil) {
+ builder.Ceil(c);
+ } else {
+ ASSERT_EQ(kFloor, f);
+ builder.Floor(c);
+ }
+ ComputeAndCompareR1<float>(&builder, expected, /*arguments=*/{});
+ }
+
+ void TestR0F32(float input, float expected, Function f) {
+ LOG(INFO) << "input: " << expected;
+ ComputationBuilder builder(client_, TestName());
+ auto c = builder.ConstantR0<float>(input);
+ if (f == kCeil) {
+ builder.Ceil(c);
+ } else {
+ ASSERT_EQ(kFloor, f);
+ builder.Floor(c);
+ }
+ ComputeAndCompareR0<float>(&builder, expected, /*arguments=*/{});
+ }
+
+ const ErrorSpec error_spec_{0.0001};
+
+ float infinity_ = std::numeric_limits<float>::infinity();
+ float minus_infinity_ = -std::numeric_limits<float>::infinity();
+};
+
+// Interesting notes:
+// * If you pass an sNaN, the CPU doesn't canonicalize it to a qNaN.
+// * Passing an x86-based CPU's qNaN to the GPU yields a different NaN:
+// "7fc00000=nan=nan vs 7fffffff=nan=nan".
+
+XLA_TEST_F(FloorCeilTest, R1S0Floor) { TestR1F32({}, {}, kFloor); }
+
+TEST_F(FloorCeilTest, R1Floor) {
+ TestR1F32({0.0, -0.0, infinity_, minus_infinity_, 1.1, -0.1},
+ {0.0, -0.0, infinity_, minus_infinity_, 1.0, -1.0}, kFloor);
+}
+
+TEST_F(FloorCeilTest, R1Ceil) {
+ TestR1F32({0.0, -0.0, infinity_, minus_infinity_, 1.1, -0.1},
+ {0.0, -0.0, infinity_, minus_infinity_, 2.0, -0.0}, kCeil);
+}
+
+TEST_F(FloorCeilTest, R0Floor) {
+ TestR0F32(0.0, 0.0, kFloor);
+ TestR0F32(-0.0, -0.0, kFloor);
+ TestR0F32(infinity_, infinity_, kFloor);
+ TestR0F32(minus_infinity_, minus_infinity_, kFloor);
+ TestR0F32(1.1, 1.0, kFloor);
+ TestR0F32(-0.1, -1.0, kFloor);
+}
+
+TEST_F(FloorCeilTest, R0Ceil) {
+ TestR0F32(0.0, 0.0, kCeil);
+ TestR0F32(-0.0, -0.0, kCeil);
+ TestR0F32(infinity_, infinity_, kCeil);
+ TestR0F32(minus_infinity_, minus_infinity_, kCeil);
+ TestR0F32(1.1, 2.0, kCeil);
+ TestR0F32(-0.1, -0.0, kCeil);
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/fmax_test.cc b/tensorflow/compiler/xla/tests/fmax_test.cc
new file mode 100644
index 0000000000..2835038c90
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/fmax_test.cc
@@ -0,0 +1,61 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <vector>
+
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace {
+
+class FmaxSimpleTest : public ClientLibraryTestBase {};
+
+TEST_F(FmaxSimpleTest, FmaxTenValues) {
+ ComputationBuilder builder(client_, TestName());
+ auto x = builder.ConstantR1<float>(
+ {-0.0, 1.0, 2.0, -3.0, -4.0, 5.0, 6.0, -7.0, -8.0, 9.0});
+ auto y = builder.ConstantR1<float>(
+ {-0.0, -1.0, -2.0, 3.0, 4.0, -5.0, -6.0, 7.0, 8.0, -9.0});
+ builder.Max(x, y);
+
+ std::vector<float> expected = {-0.0, 1.0, 2.0, 3.0, 4.0,
+ 5.0, 6.0, 7.0, 8.0, 9.0};
+ ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/fusion_test.cc b/tensorflow/compiler/xla/tests/fusion_test.cc
new file mode 100644
index 0000000000..7bddbfa894
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/fusion_test.cc
@@ -0,0 +1,589 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <math.h>
+#include <algorithm>
+#include <memory>
+#include <new>
+#include <utility>
+
+#include "tensorflow/compiler/xla/array2d.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/primitive_util.h"
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+using tensorflow::gtl::ArraySlice;
+
+namespace xla {
+namespace {
+
+const int test_width = 2, test_height = 3;
+
+const float test_float_vals[3][test_width][test_height] = {
+ {{-1.0, -1.0, 1.0}, {-3.0, 0.0, -1.0}},
+ {{-3.0, 2.0, 1.0}, {0.0, -3.0, 1.0}},
+ {{-3.0, 0.0, -3.0}, {-1.0, -2.0, 1.0}}};
+
+// Test whether fusion operations are emitted with no errors and compute
+// accurate outputs.
+class FusionTest : public HloTestBase {
+ protected:
+ template <typename T, int Arity>
+ void TestElementwise2D(HloOpcode opcode) {
+ Array2D<float> operand_data[Arity];
+ for (int i = 0; i < Arity; ++i) {
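+      // Construct each operand in place with the test dimensions via placement
+      // new, since a C-style array cannot be initialized with constructor
+      // arguments.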
+ new (&operand_data[i]) Array2D<float>(test_width, test_height);
+ }
+ Array2D<T> answer_data(test_width, test_height);
+ for (int i = 0; i < test_width; ++i) {
+ for (int j = 0; j < test_height; ++j) {
+ float xs[Arity];
+ for (int k = 0; k < Arity; ++k) {
+ xs[k] = test_float_vals[k][i][j];
+ operand_data[k](i, j) = xs[k];
+ }
+ answer_data(i, j) = ComputeElementwiseAnswer<T>(opcode, xs);
+ }
+ }
+
+ auto builder = HloComputation::Builder(TestName());
+ auto hlo_module = MakeUnique<HloModule>(TestName());
+
+ auto prim_type = primitive_util::NativeToPrimitiveType<T>();
+
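+    // hlos[0] will hold the fused root and hlos[1..Arity] the constant
+    // operands (maximum arity is 3, hence four slots). CreateFusionInstruction
+    // below expects the root first, i.e. reverse topological order.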
+ HloInstruction* hlos[4];
+ for (int i = 0; i < Arity; ++i) {
+ hlos[i + 1] = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR2FromArray2D(operand_data[i])));
+ }
+ auto answer_shape =
+ ShapeUtil::MakeShape(prim_type, {test_width, test_height});
+ std::unique_ptr<HloInstruction> root_hlo;
+ switch (Arity) {
+ case 1:
+ root_hlo = HloInstruction::CreateUnary(answer_shape, opcode, hlos[1]);
+ break;
+ case 2:
+ root_hlo = HloInstruction::CreateBinary(answer_shape, opcode, hlos[1],
+ hlos[2]);
+ break;
+ case 3:
+ root_hlo = HloInstruction::CreateTernary(answer_shape, opcode, hlos[1],
+ hlos[2], hlos[3]);
+ break;
+ default:
+ LOG(FATAL) << "Bad arity: " << Arity;
+ }
+ hlos[0] = builder.AddInstruction(std::move(root_hlo));
+ hlo_module->AddEntryComputation(builder.Build())
+ ->CreateFusionInstruction(
+ ArraySlice<HloInstruction*>(hlos, 0, Arity + 1),
+ HloInstruction::FusionKind::kLoop);
+
+ auto expected = LiteralUtil::CreateR2FromArray2D(answer_data);
+ auto actual = ExecuteAndTransfer(std::move(hlo_module), {});
+ if (primitive_util::IsFloatingPointType(prim_type)) {
+ LiteralTestUtil::ExpectNear(*expected, *actual, ErrorSpec(1e-4));
+ } else {
+ LiteralTestUtil::ExpectEqual(*expected, *actual);
+ }
+ }
+
+ private:
+ template <typename T>
+ T ComputeElementwiseAnswer(HloOpcode opcode, ArraySlice<float> xs);
+};
+
+template <>
+float FusionTest::ComputeElementwiseAnswer<float>(HloOpcode opcode,
+ ArraySlice<float> xs) {
+ switch (opcode) {
+ case HloOpcode::kAdd:
+ return xs[0] + xs[1];
+ case HloOpcode::kSubtract:
+ return xs[0] - xs[1];
+ case HloOpcode::kMultiply:
+ return xs[0] * xs[1];
+ case HloOpcode::kDivide:
+ return xs[0] / xs[1];
+ case HloOpcode::kPower:
+ return powf(xs[0], xs[1]);
+ case HloOpcode::kMinimum:
+ return std::min(xs[0], xs[1]);
+ case HloOpcode::kMaximum:
+ return std::max(xs[0], xs[1]);
+ case HloOpcode::kClamp:
+ return std::min(xs[2], std::max(xs[1], xs[0]));
+ default:
+ LOG(FATAL) << "No elementwise opcode: " << opcode;
+ }
+}
+
+template <>
+uint8 FusionTest::ComputeElementwiseAnswer<uint8>(HloOpcode opcode,
+ ArraySlice<float> xs) {
+ switch (opcode) {
+ case HloOpcode::kEq:
+ return xs[0] == xs[1];
+ case HloOpcode::kNe:
+ return xs[0] != xs[1];
+ case HloOpcode::kGt:
+ return xs[0] > xs[1];
+ case HloOpcode::kLt:
+ return xs[0] < xs[1];
+ case HloOpcode::kGe:
+ return xs[0] >= xs[1];
+ case HloOpcode::kLe:
+ return xs[0] <= xs[1];
+ default:
+ LOG(FATAL) << "No comparatory opcode: " << opcode;
+ }
+}
+
+XLA_TEST_F(FusionTest, Test) {
+ // test expression:
+ // slice(select({{T, F, T}, {F, T, F}},
+ // concat(transpose({{1.0}, {2.0}, {3.0}} +
+ // {{-1.0}, {-1.0}, {-1.0}}),
+ // {{1.62, 2.72, 3.14}}) +
+ // (-{{1.0, 1.0, 1.0}, {0.0, 0.0, 0.0}}),
+ // {{0.5, 0.5, 0.5}, {0.5, 0.5, 0.5}})) = {{0.5}, {2.72}}
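+  // Evaluated step by step:
+  //   add2     = {{0.0}, {1.0}, {2.0}}
+  //   reshape3 = {{0.0, 1.0, 2.0}}
+  //   concat5  = {{0.0, 1.0, 2.0}, {1.62, 2.72, 3.14}}
+  //   add8     = {{-1.0, 0.0, 1.0}, {1.62, 2.72, 3.14}}
+  //   select11 = {{-1.0, 0.5, 1.0}, {0.5, 2.72, 0.5}}
+  //   slice12  = column 1 of select11 = {{0.5}, {2.72}}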
+ auto builder = HloComputation::Builder(TestName());
+ auto hlo_module = MakeUnique<HloModule>(TestName());
+ auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR2<float>({{1.0}, {2.0}, {3.0}})));
+ auto const1 = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR2<float>({{-1.0}, {-1.0}, {-1.0}})));
+ auto add2 = builder.AddInstruction(HloInstruction::CreateBinary(
+ ShapeUtil::MakeShape(F32, {3, 1}), HloOpcode::kAdd, const0, const1));
+ auto reshape3 = builder.AddInstruction(HloInstruction::CreateTranspose(
+ ShapeUtil::MakeShape(F32, {1, 3}), add2, {1, 0}));
+ auto const4 = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR2<float>({{1.62, 2.72, 3.14}})));
+ auto concat5 = builder.AddInstruction(HloInstruction::CreateConcatenate(
+ ShapeUtil::MakeShape(F32, {2, 3}), {reshape3, const4}, 0));
+ auto const6 = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR2<float>({{1.0, 1.0, 1.0}, {0.0, 0.0, 0.0}})));
+ auto negate7 = builder.AddInstruction(HloInstruction::CreateUnary(
+ ShapeUtil::MakeShape(F32, {2, 3}), HloOpcode::kNegate, const6));
+ auto add8 = builder.AddInstruction(HloInstruction::CreateBinary(
+ ShapeUtil::MakeShape(F32, {2, 3}), HloOpcode::kAdd, concat5, negate7));
+ auto const9 = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR2<float>({{0.5, 0.5, 0.5}, {0.5, 0.5, 0.5}})));
+ auto const10 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR2<bool>(
+ {{true, false, true}, {false, true, false}})));
+ auto select11 = builder.AddInstruction(
+ HloInstruction::CreateTernary(ShapeUtil::MakeShape(F32, {2, 3}),
+ HloOpcode::kSelect, const10, add8, const9));
+ auto slice12 = builder.AddInstruction(HloInstruction::CreateSlice(
+ ShapeUtil::MakeShape(F32, {2, 1}), select11, {0, 1}, {2, 2}));
+ // CreateFusionInstruction needs the `instructions_to_fuse` argument in
+ // reverse topological order, so the first element in `instructions_to_fuse`
+ // must be the root.
+ hlo_module->AddEntryComputation(builder.Build())
+ ->CreateFusionInstruction(
+ {slice12, select11, const10, const9, add8, negate7, const6, concat5,
+ const4, reshape3, add2, const1, const0},
+ HloInstruction::FusionKind::kLoop);
+
+ LiteralTestUtil::ExpectNear(*LiteralUtil::CreateR2<float>({{0.5}, {2.72}}),
+ *ExecuteAndTransfer(std::move(hlo_module), {}),
+ ErrorSpec(1e-4));
+}
+
+// Test whether we emit appropriate code for parameters of fusion instructions.
+XLA_TEST_F(FusionTest, Parameter) {
+ // Build a computation and fuse part of it so the fusion instruction has an
+ // operand parameter.
+ auto builder = HloComputation::Builder(TestName());
+ auto hlo_module = MakeUnique<HloModule>(TestName());
+ auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR2<float>({{1.0, 2.0, 3.0}})));
+ auto copy1 = builder.AddInstruction(HloInstruction::CreateUnary(
+ ShapeUtil::MakeShape(F32, {1, 3}), HloOpcode::kCopy, const0));
+ auto const2 = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR2<float>({{-2.0, -2.0, -2.0}})));
+ // add3 = copy1 + const2 = const0 + const2 = {1,2,3} + {-2,-2,-2} = {-1,0,+1}
+ auto add3 = builder.AddInstruction(HloInstruction::CreateBinary(
+ ShapeUtil::MakeShape(F32, {1, 3}), HloOpcode::kAdd, copy1, const2));
+ // CreateFusionInstruction needs `instructions_to_fuse` in reverse topological
+ // order.
+ hlo_module->AddEntryComputation(builder.Build())
+ ->CreateFusionInstruction(/*instructions_to_fuse=*/{add3, const2},
+ HloInstruction::FusionKind::kLoop);
+
+ LiteralTestUtil::ExpectNear(*LiteralUtil::CreateR2<float>({{-1.0, 0.0, 1.0}}),
+ *ExecuteAndTransfer(std::move(hlo_module), {}),
+ ErrorSpec(1e-4));
+}
+
+XLA_TEST_F(FusionTest, BroadcastIntoBinaryOp) {
+ auto builder = HloComputation::Builder(TestName());
+ auto hlo_module = MakeUnique<HloModule>(TestName());
+ auto const_vector = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0})));
+ auto const_array = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR2<float>({{-1.0, -2.0, -4.0}, {10.0, 20.0, 30.0}})));
+ auto broadcast = builder.AddInstruction(
+ HloInstruction::CreateBroadcast(const_array->shape(), const_vector, {1}));
+ // add2 = broadcast(const_vector) + const_array
+ // = broadcast({1,2,3}) + {{-1.0, -2.0, -4.0}, {10.0, 20.0, 30.0}}
+ // = {{1, 2, 3}, {1, 2, 3}} + {{-1.0, -2.0, -4.0}, {10.0, 20.0, 30.0}}
+ auto add2 = builder.AddInstruction(
+ HloInstruction::CreateBinary(ShapeUtil::MakeShape(F32, {2, 3}),
+ HloOpcode::kAdd, broadcast, const_array));
+ hlo_module->AddEntryComputation(builder.Build())
+ ->CreateFusionInstruction(/*instructions_to_fuse=*/{add2, broadcast},
+ HloInstruction::FusionKind::kLoop);
+
+ LiteralTestUtil::ExpectNear(
+ *LiteralUtil::CreateR2<float>({{0.0, 0.0, -1.0}, {11.0, 22.0, 33.0}}),
+ *ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4));
+}
+
+XLA_TEST_F(FusionTest, ReshapeToScalar) {
+ auto builder = HloComputation::Builder(TestName());
+ auto hlo_module = MakeUnique<HloModule>(TestName());
+ auto single_element_array = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR2<int32>({{5}})));
+ auto reshape = builder.AddInstruction(HloInstruction::CreateReshape(
+ ShapeUtil::MakeShape(S32, {}), single_element_array));
+ hlo_module->AddEntryComputation(builder.Build())
+ ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape},
+ HloInstruction::FusionKind::kLoop);
+ LiteralTestUtil::ExpectEqual(*LiteralUtil::CreateR0<int32>(5),
+ *ExecuteAndTransfer(std::move(hlo_module), {}));
+}
+
+XLA_TEST_F(FusionTest, Reshape_3by2_1by2by3) {
+ auto builder = HloComputation::Builder(TestName());
+ auto hlo_module = MakeUnique<HloModule>(TestName());
+ auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR2<int32>({{1, 2}, {3, 4}, {5, 6}})));
+ auto reshape1 = builder.AddInstruction(HloInstruction::CreateReshape(
+ ShapeUtil::MakeShape(S32, {1, 2, 3}), const0));
+ hlo_module->AddEntryComputation(builder.Build())
+ ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
+ HloInstruction::FusionKind::kLoop);
+ LiteralTestUtil::ExpectEqual(
+ *LiteralUtil::CreateR3<int32>({{{1, 2, 3}, {4, 5, 6}}}),
+ *ExecuteAndTransfer(std::move(hlo_module), {}));
+}
+
+XLA_TEST_F(FusionTest, Reshape_1by2by3_3by2) {
+ auto builder = HloComputation::Builder(TestName());
+ auto hlo_module = MakeUnique<HloModule>(TestName());
+ auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR3<int32>({{{1, 2, 3}, {4, 5, 6}}})));
+ auto reshape1 = builder.AddInstruction(
+ HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {3, 2}), const0));
+ hlo_module->AddEntryComputation(builder.Build())
+ ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
+ HloInstruction::FusionKind::kLoop);
+ LiteralTestUtil::ExpectEqual(
+ *LiteralUtil::CreateR2<int32>({{1, 2}, {3, 4}, {5, 6}}),
+ *ExecuteAndTransfer(std::move(hlo_module), {}));
+}
+
+XLA_TEST_F(FusionTest, Reshape_1by1by1_) {
+ auto builder = HloComputation::Builder(TestName());
+ auto hlo_module = MakeUnique<HloModule>(TestName());
+ auto const0 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR3<int32>({{{7}}})));
+ auto reshape1 = builder.AddInstruction(
+ HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {}), const0));
+ hlo_module->AddEntryComputation(builder.Build())
+ ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
+ HloInstruction::FusionKind::kLoop);
+ LiteralTestUtil::ExpectEqual(*LiteralUtil::CreateR0<int32>(7),
+ *ExecuteAndTransfer(std::move(hlo_module), {}));
+}
+
+XLA_TEST_F(FusionTest, Reshape__1by1by1) {
+ auto builder = HloComputation::Builder(TestName());
+ auto hlo_module = MakeUnique<HloModule>(TestName());
+ auto const0 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(7)));
+ auto reshape1 = builder.AddInstruction(HloInstruction::CreateReshape(
+ ShapeUtil::MakeShape(S32, {1, 1, 1}), const0));
+ hlo_module->AddEntryComputation(builder.Build())
+ ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
+ HloInstruction::FusionKind::kLoop);
+ LiteralTestUtil::ExpectEqual(*LiteralUtil::CreateR3<int32>({{{7}}}),
+ *ExecuteAndTransfer(std::move(hlo_module), {}));
+}
+
+XLA_TEST_F(FusionTest, Reshape__) {
+ auto builder = HloComputation::Builder(TestName());
+ auto hlo_module = MakeUnique<HloModule>(TestName());
+ auto const0 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(7)));
+ auto reshape1 = builder.AddInstruction(
+ HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {}), const0));
+ hlo_module->AddEntryComputation(builder.Build())
+ ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
+ HloInstruction::FusionKind::kLoop);
+ LiteralTestUtil::ExpectEqual(*LiteralUtil::CreateR0<int32>(7),
+ *ExecuteAndTransfer(std::move(hlo_module), {}));
+}
+
+XLA_TEST_F(FusionTest, Reshape_3by3_3by3) {
+ auto builder = HloComputation::Builder(TestName());
+ auto hlo_module = MakeUnique<HloModule>(TestName());
+ auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}})));
+ auto reshape1 = builder.AddInstruction(
+ HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {3, 3}), const0));
+ hlo_module->AddEntryComputation(builder.Build())
+ ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
+ HloInstruction::FusionKind::kLoop);
+ LiteralTestUtil::ExpectEqual(
+ *LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}),
+ *ExecuteAndTransfer(std::move(hlo_module), {}));
+}
+
+XLA_TEST_F(FusionTest, Transpose_2by3) {
+ auto builder = HloComputation::Builder(TestName());
+ auto hlo_module = MakeUnique<HloModule>(TestName());
+ auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}})));
+ auto reshape1 = builder.AddInstruction(HloInstruction::CreateTranspose(
+ ShapeUtil::MakeShape(S32, {3, 2}), const0, {1, 0}));
+ hlo_module->AddEntryComputation(builder.Build())
+ ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
+ HloInstruction::FusionKind::kLoop);
+ LiteralTestUtil::ExpectEqual(
+ *LiteralUtil::CreateR2<int32>({{1, 4}, {2, 5}, {3, 6}}),
+ *ExecuteAndTransfer(std::move(hlo_module), {}));
+}
+
+XLA_TEST_F(FusionTest, Transpose_3by3) {
+ auto builder = HloComputation::Builder(TestName());
+ auto hlo_module = MakeUnique<HloModule>(TestName());
+ auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}})));
+ auto reshape1 = builder.AddInstruction(HloInstruction::CreateTranspose(
+ ShapeUtil::MakeShape(S32, {3, 3}), const0, {1, 0}));
+ hlo_module->AddEntryComputation(builder.Build())
+ ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
+ HloInstruction::FusionKind::kLoop);
+ LiteralTestUtil::ExpectEqual(
+ *LiteralUtil::CreateR2<int32>({{1, 4, 7}, {2, 5, 8}, {3, 6, 9}}),
+ *ExecuteAndTransfer(std::move(hlo_module), {}));
+}
+
+XLA_TEST_F(FusionTest, Reverse) {
+ auto builder = HloComputation::Builder(TestName());
+ auto hlo_module = MakeUnique<HloModule>(TestName());
+ auto const0 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<int32>({1, 2, 3})));
+ auto reverse1 = builder.AddInstruction(HloInstruction::CreateReverse(
+ ShapeUtil::MakeShape(S32, {3}), const0, {0}));
+ hlo_module->AddEntryComputation(builder.Build())
+ ->CreateFusionInstruction(/*instructions_to_fuse=*/{reverse1},
+ HloInstruction::FusionKind::kLoop);
+
+ LiteralTestUtil::ExpectEqual(*LiteralUtil::CreateR1<int32>({3, 2, 1}),
+ *ExecuteAndTransfer(std::move(hlo_module), {}));
+}
+
+std::unique_ptr<HloComputation> MakeReduceTestComputation() {
+ auto builder = HloComputation::Builder("add");
+ auto lhs = builder.AddInstruction(HloInstruction::CreateParameter(
+ /*parameter_number=*/0, ShapeUtil::MakeShape(S32, {}), "lhs"));
+ auto rhs = builder.AddInstruction(HloInstruction::CreateParameter(
+ /*parameter_number=*/1, ShapeUtil::MakeShape(S32, {}), "rhs"));
+ builder.AddInstruction(HloInstruction::CreateBinary(
+ ShapeUtil::MakeShape(S32, {}), HloOpcode::kAdd, lhs, rhs));
+ return builder.Build();
+}
+
+XLA_TEST_F(FusionTest, DISABLED_ON_CPU(Reduce)) {
+ auto hlo_module = MakeUnique<HloModule>(TestName());
+
+ auto builder = HloComputation::Builder(TestName());
+ auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR1<int32>({1, 2, 4, 8})));
+ auto const1 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(0)));
+ auto reduce2 = builder.AddInstruction(HloInstruction::CreateReduce(
+ ShapeUtil::MakeShape(S32, {}), const0, const1, {0},
+ hlo_module->AddEmbeddedComputation(MakeReduceTestComputation())));
+ hlo_module->AddEntryComputation(builder.Build())
+ ->CreateFusionInstruction(/*instructions_to_fuse=*/{reduce2},
+ HloInstruction::FusionKind::kLoop);
+
+ LiteralTestUtil::ExpectEqual(*LiteralUtil::CreateR0<int32>(15),
+ *ExecuteAndTransfer(std::move(hlo_module), {}));
+}
+
+XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceImplicitBroadcast)) {
+ auto hlo_module = MakeUnique<HloModule>(TestName());
+
+ auto builder = HloComputation::Builder(TestName());
+ auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR1<int32>({1, 2, 4, 8})));
+ auto const1 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(0)));
+ auto reduce2 = builder.AddInstruction(HloInstruction::CreateReduce(
+ ShapeUtil::MakeShape(S32, {}), const0, const1, {0},
+ hlo_module->AddEmbeddedComputation(MakeReduceTestComputation())));
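+  // reduce2 produces a scalar; negating it into shape {1} relies on the
+  // implicit broadcast from rank 0 to rank 1 that this test exercises.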
+ auto negate3 = builder.AddInstruction(HloInstruction::CreateUnary(
+ ShapeUtil::MakeShape(S32, {1}), HloOpcode::kNegate, reduce2));
+ hlo_module->AddEntryComputation(builder.Build())
+ ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate3, reduce2},
+ HloInstruction::FusionKind::kLoop);
+
+ LiteralTestUtil::ExpectEqual(*LiteralUtil::CreateR1<int32>({-15}),
+ *ExecuteAndTransfer(std::move(hlo_module), {}));
+}
+
+XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceWindow)) {
+ auto builder = HloComputation::Builder(TestName());
+ auto hlo_module = MakeUnique<HloModule>(TestName());
+ auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR2<int32>({{2, 3, 5}, {7, 11, 13}, {17, 19, 23}})));
+ auto const1 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));
+ Window window;
+ ASSERT_TRUE(
+ tensorflow::protobuf::TextFormat::ParseFromString("dimensions:{\n"
+ "size:2\n"
+ "stride:1\n"
+ "padding_low:0\n"
+ "padding_high:0\n"
+ "window_dilation:1\n"
+ "base_dilation:1\n"
+ "}\n"
+ "dimensions:{\n"
+ "size:2\n"
+ "stride:1\n"
+ "padding_low:0\n"
+ "padding_high:0\n"
+ "window_dilation:1\n"
+ "base_dilation:1\n"
+ "}\n",
+ &window));
+ auto nested_builder = HloComputation::Builder("mul");
+ {
+ auto x = nested_builder.AddInstruction(
+ HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(S32, {}), "x"));
+ auto y = nested_builder.AddInstruction(
+ HloInstruction::CreateParameter(1, ShapeUtil::MakeShape(S32, {}), "y"));
+ nested_builder.AddInstruction(HloInstruction::CreateBinary(
+ ShapeUtil::MakeShape(S32, {}), HloOpcode::kMultiply, x, y));
+ }
+ auto nested_computation =
+ hlo_module->AddEmbeddedComputation(nested_builder.Build());
+ auto reduce_window2 =
+ builder.AddInstruction(HloInstruction::CreateReduceWindow(
+ ShapeUtil::MakeShape(S32, {2, 2}), const0, const1, window,
+ nested_computation));
+ hlo_module->AddEntryComputation(builder.Build())
+ ->CreateFusionInstruction(/*instructions_to_fuse=*/{reduce_window2},
+ HloInstruction::FusionKind::kLoop);
+
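+  // The 2x2 window (stride 1, no padding) slides over the 3x3 input, so the
+  // output is 2x2 and each element is the product over its window:
+  //   2*3*7*11 = 462,     3*5*11*13 = 2145,
+  //   7*11*17*19 = 24871, 11*13*19*23 = 62491.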
+ LiteralTestUtil::ExpectEqual(
+ *LiteralUtil::CreateR2<int32>({{462, 2145}, {24871, 62491}}),
+ *ExecuteAndTransfer(std::move(hlo_module), {}));
+}
+
+XLA_TEST_F(FusionTest, Add2D) { TestElementwise2D<float, 2>(HloOpcode::kAdd); }
+
+XLA_TEST_F(FusionTest, Subtract2D) {
+ TestElementwise2D<float, 2>(HloOpcode::kSubtract);
+}
+
+XLA_TEST_F(FusionTest, Multiply2D) {
+ TestElementwise2D<float, 2>(HloOpcode::kMultiply);
+}
+
+XLA_TEST_F(FusionTest, Divide2D) {
+ TestElementwise2D<float, 2>(HloOpcode::kDivide);
+}
+
+XLA_TEST_F(FusionTest, Power2D) {
+ TestElementwise2D<float, 2>(HloOpcode::kPower);
+}
+
+XLA_TEST_F(FusionTest, Minimum2D) {
+ TestElementwise2D<float, 2>(HloOpcode::kMinimum);
+}
+
+XLA_TEST_F(FusionTest, Maximum2D) {
+ TestElementwise2D<float, 2>(HloOpcode::kMaximum);
+}
+
+XLA_TEST_F(FusionTest, Equal2D) { TestElementwise2D<uint8, 2>(HloOpcode::kEq); }
+
+XLA_TEST_F(FusionTest, Inequal2D) {
+ TestElementwise2D<uint8, 2>(HloOpcode::kNe);
+}
+
+XLA_TEST_F(FusionTest, Greater2D) {
+ TestElementwise2D<uint8, 2>(HloOpcode::kGt);
+}
+
+XLA_TEST_F(FusionTest, Lesser2D) {
+ TestElementwise2D<uint8, 2>(HloOpcode::kLt);
+}
+
+XLA_TEST_F(FusionTest, GreaterOrEqual2D) {
+ TestElementwise2D<uint8, 2>(HloOpcode::kGe);
+}
+
+XLA_TEST_F(FusionTest, LesserOrEqual2D) {
+ TestElementwise2D<uint8, 2>(HloOpcode::kLe);
+}
+
+XLA_TEST_F(FusionTest, Clamp2D) {
+ TestElementwise2D<float, 3>(HloOpcode::kClamp);
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc
new file mode 100644
index 0000000000..872188de81
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc
@@ -0,0 +1,204 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+
+#include <set>
+#include <string>
+#include <utility>
+
+#define EIGEN_USE_THREADS
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/legacy_flags/hlo_test_base_flags.h"
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/service/backend.h"
+#include "tensorflow/compiler/xla/service/computation_layout.h"
+#include "tensorflow/compiler/xla/service/executable.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_execution_profile.h"
+#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module_config.h"
+#include "tensorflow/compiler/xla/service/transfer_manager.h"
+#include "tensorflow/compiler/xla/shape_layout.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/common_runtime/eigen_thread_pool.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace se = ::perftools::gputools;
+
+namespace xla {
+
+// Defined in the .cc file to avoid having to include Eigen or forward-declare
+// these types in the header.
+struct HloTestBase::EigenThreadPoolWrapper {
+ std::unique_ptr<EigenThreadPoolWrapper> pool;
+ std::unique_ptr<Eigen::ThreadPoolDevice> device;
+};
+
+HloTestBase::HloTestBase()
+ : backend_(Backend::CreateDefaultBackend().ConsumeValueOrDie()) {
+ test_hlo_dumper_ = [](const HloModule& module, const string& label) {
+ legacy_flags::HloTestBaseFlags* flags = legacy_flags::GetHloTestBaseFlags();
+ if (flags->xla_hlo_test_generate_hlo_graph) {
+ const bool show_addresses = true;
+ const bool show_layouts = true;
+ hlo_graph_dumper::DumpGraph(*module.entry_computation(), label,
+ show_addresses, show_layouts);
+ }
+ };
+ VLOG(1) << "executing on platform " << backend_->platform()->Name();
+}
+
+HloTestBase::~HloTestBase() {
+ // Deallocate all the memory allocated during the tests.
+ for (auto& allocation : allocations_) {
+ backend_->default_stream_executor()->Deallocate(&allocation);
+ }
+}
+
+StatusOr<perftools::gputools::DeviceMemoryBase> HloTestBase::Execute(
+ std::unique_ptr<HloModule> module,
+ tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
+ arguments,
+ Shape* result_shape) {
+ auto module_config = MakeUnique<HloModuleConfig>(
+ MakeProgramShape(module->entry_computation()));
+ return Execute(std::move(module), std::move(module_config), arguments,
+ result_shape);
+}
+
+StatusOr<se::DeviceMemoryBase> HloTestBase::Execute(
+ std::unique_ptr<HloModule> hlo_module,
+ std::unique_ptr<HloModuleConfig> module_config,
+ tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> arguments,
+ Shape* result_shape) {
+ VLOG(3) << "module_config layout "
+ << LayoutUtil::HumanString(module_config->entry_computation_layout()
+ .result_layout()
+ .layout());
+ TF_ASSIGN_OR_RETURN(
+ std::unique_ptr<Executable> executable,
+ backend_->compiler()->Compile(std::move(hlo_module),
+ std::move(module_config), test_hlo_dumper_,
+ backend_->default_stream_executor()));
+
+ se::Stream stream(backend_->default_stream_executor());
+ stream.Init();
+
+ ExecutableRunOptions run_options;
+ run_options.set_stream(&stream);
+ run_options.set_allocator(backend_->memory_allocator());
+ run_options.set_inter_op_thread_pool(backend_->inter_op_thread_pool());
+ run_options.set_intra_op_thread_pool(
+ backend_->eigen_intra_op_thread_pool_device());
+
+ HloExecutionProfile hlo_execution_profile;
+ TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase result,
+ executable->ExecuteOnStream(&run_options, arguments,
+ &hlo_execution_profile));
+ TF_RET_CHECK(stream.BlockHostUntilDone());
+
+ allocations_.push_back(result);
+
+ *result_shape = executable->result_shape();
+
+ if (ShapeUtil::IsTuple(*result_shape)) {
+ // We must record element buffers of tuples as well to avoid leaks.
+ DCHECK(!ShapeUtil::IsNestedTuple(*result_shape));
+ TF_ASSIGN_OR_RETURN(
+ std::vector<se::DeviceMemoryBase> element_buffers,
+ backend_->transfer_manager()->ShallowCopyTupleFromDevice(
+ backend_->default_stream_executor(), result, *result_shape));
+
+ // A tuple may contain the same buffer in more than one element. Keep track
+ // of the buffers already added to avoid duplicates in allocations_.
+ std::set<void*> added_opaques;
+ for (auto element_buffer : element_buffers) {
+ if (added_opaques.count(element_buffer.opaque()) == 0) {
+ added_opaques.insert(element_buffer.opaque());
+ allocations_.push_back(element_buffer);
+ }
+ }
+ }
+
+ return result;
+}
+
+se::DeviceMemoryBase HloTestBase::TransferToDevice(const Literal& literal) {
+ // Allocate memory on the device using the stream executor.
+ int64 allocation_size =
+ backend_->transfer_manager()->GetByteSizeRequirement(literal.shape());
+ se::DeviceMemoryBase allocation =
+ backend_->default_stream_executor()->AllocateArray<uint8>(
+ allocation_size);
+ allocations_.push_back(allocation);
+
+ TF_CHECK_OK(backend_->transfer_manager()->TransferLiteralToDevice(
+ backend_->default_stream_executor(), literal, &allocation));
+
+ return allocation;
+}
+
+std::unique_ptr<Literal> HloTestBase::TransferFromDevice(
+ const Shape& shape, se::DeviceMemoryBase device_base) {
+ auto literal = MakeUnique<Literal>();
+ TF_CHECK_OK(backend_->transfer_manager()->TransferLiteralFromDevice(
+ backend_->default_stream_executor(), device_base, shape, shape,
+ literal.get()));
+ return literal;
+}
+
+std::unique_ptr<Literal> HloTestBase::ExecuteAndTransfer(
+ std::unique_ptr<HloModule> module,
+ tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> arguments) {
+ Shape result_shape;
+ se::DeviceMemoryBase device_base =
+ Execute(std::move(module), arguments, &result_shape).ValueOrDie();
+ return TransferFromDevice(result_shape, device_base);
+}
+
+std::unique_ptr<Literal> HloTestBase::ExecuteAndTransfer(
+ std::unique_ptr<HloModule> module,
+ std::unique_ptr<HloModuleConfig> module_config,
+ tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> arguments) {
+ Shape result_shape;
+ se::DeviceMemoryBase device_base =
+ Execute(std::move(module), std::move(module_config), arguments,
+ &result_shape)
+ .ValueOrDie();
+ return TransferFromDevice(result_shape, device_base);
+}
+
+ProgramShape HloTestBase::MakeProgramShape(HloComputation* computation) {
+ ProgramShape program_shape;
+ for (int64 i = 0; i < computation->num_parameters(); ++i) {
+ *program_shape.add_parameters() =
+ computation->parameter_instruction(i)->shape();
+ }
+ *program_shape.mutable_result() = computation->root_instruction()->shape();
+ return program_shape;
+}
+
+string HloTestBase::TestName() const {
+ return ::testing::UnitTest::GetInstance()->current_test_info()->name();
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h
new file mode 100644
index 0000000000..fa88c76899
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/hlo_test_base.h
@@ -0,0 +1,107 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_TESTS_HLO_TEST_BASE_H_
+#define TENSORFLOW_COMPILER_XLA_TESTS_HLO_TEST_BASE_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "tensorflow/compiler/xla/service/backend.h"
+#include "tensorflow/compiler/xla/service/compiler.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_module_config.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+
+// A base class for tests which build and run HLO code. This is a lower level
+// of abstraction than the client interface and enables, among other things,
+// explicitly building a graph of HLO instructions to run.
+class HloTestBase : public ::testing::Test {
+ protected:
+ struct EigenThreadPoolWrapper;
+ HloTestBase();
+
+ ~HloTestBase() override;
+
+  // Executes the given module and returns a handle to the result in device
+  // memory.
+ StatusOr<perftools::gputools::DeviceMemoryBase> Execute(
+ std::unique_ptr<HloModule> module,
+ tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
+ arguments,
+ Shape* result_shape);
+
+ // Variation of Execute which takes a custom module_config instead of creating
+ // a default one.
+ StatusOr<perftools::gputools::DeviceMemoryBase> Execute(
+ std::unique_ptr<HloModule> module,
+ std::unique_ptr<HloModuleConfig> module_config,
+ tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
+ arguments,
+ Shape* result_shape);
+
+ // Transfers the given literal to the device and returns the data handle.
+ perftools::gputools::DeviceMemoryBase TransferToDevice(
+ const Literal& literal);
+
+  // Transfers the array referred to by the given handle from the device and
+  // returns it as a Literal.
+ std::unique_ptr<Literal> TransferFromDevice(
+ const Shape& shape, perftools::gputools::DeviceMemoryBase device_base);
+
+  // Executes the given module and returns the result as a Literal.
+ std::unique_ptr<Literal> ExecuteAndTransfer(
+ std::unique_ptr<HloModule> module,
+ tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
+ arguments);
+
+ // Variation of ExecuteAndTransfer which takes a custom module_config instead
+ // of creating a default one.
+ std::unique_ptr<Literal> ExecuteAndTransfer(
+ std::unique_ptr<HloModule> module,
+ std::unique_ptr<HloModuleConfig> module_config,
+ tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
+ arguments);
+
+ // Utility function which creates a ProgramShape for a given computation.
+ ProgramShape MakeProgramShape(HloComputation* computation);
+
+ string TestName() const;
+
+ std::unique_ptr<Backend> backend_;
+
+ Compiler::HloDumper test_hlo_dumper_;
+
+ // This vector contains handles of all the device memory allocations performed
+ // by the test. These are deallocated on destruction of the test object.
+ std::vector<perftools::gputools::DeviceMemoryBase> allocations_;
+
+ ErrorSpec error_spec_{0.0001};
+
+ std::unique_ptr<EigenThreadPoolWrapper> thread_pool_wrapper_;
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_TESTS_HLO_TEST_BASE_H_
diff --git a/tensorflow/compiler/xla/tests/inprocess_service_test.cc b/tensorflow/compiler/xla/tests/inprocess_service_test.cc
new file mode 100644
index 0000000000..9909f041de
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/inprocess_service_test.cc
@@ -0,0 +1,204 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <initializer_list>
+#include <memory>
+#include <vector>
+
+#include "tensorflow/compiler/xla/array2d.h"
+#include "tensorflow/compiler/xla/client/computation.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/global_data.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+// Tests which exercise the "InProcess" methods of xla::Client. The
+// "InProcess" methods require that the client and server share the same
+// process.
+class InProcessServiceTest : public ClientLibraryTestBase {
+ protected:
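+  // Builds a computation which returns the given R2 constant and executes it,
+  // requesting the result with the given layout; a minor_to_major of {1, 0} is
+  // row-major and {0, 1} is column-major.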
+ std::unique_ptr<GlobalData> ExecuteR2F32Constant(
+ std::initializer_list<std::initializer_list<float>> values,
+ tensorflow::gtl::ArraySlice<int64> minor_to_major) {
+ ComputationBuilder builder(client_, TestName());
+ builder.ConstantR2<float>(values);
+ auto computation = builder.Build().ConsumeValueOrDie();
+ CHECK_EQ(2, minor_to_major.size());
+ Shape shape_with_layout = ShapeUtil::MakeShapeWithLayout(
+ F32,
+ /*dimensions=*/{static_cast<int64>(values.size()),
+ static_cast<int64>(values.begin()->size())},
+ minor_to_major);
+ return client_
+ ->Execute(computation, {}, &shape_with_layout,
+ /*execution_profile=*/nullptr)
+ .ConsumeValueOrDie();
+ }
+
+ ErrorSpec error_spec_{0.0001};
+};
+
+XLA_TEST_F(InProcessServiceTest, TransferFromServer) {
+ ComputationBuilder builder(client_, TestName());
+ builder.ConstantR1<int32>({1, 42, 5});
+ auto computation = builder.Build().ConsumeValueOrDie();
+
+ auto handle = client_->Execute(computation, {}).ConsumeValueOrDie();
+
+ std::vector<int32> result(3, 0);
+ ASSERT_IS_OK(client_->TransferInProcess(*handle, result.data()));
+ EXPECT_MATCH(result, testing::VectorMatcher<int32>({1, 42, 5}));
+}
+
+XLA_TEST_F(InProcessServiceTest, TransferToServer) {
+ std::vector<float> input{1.0f, 2.0f, -42.0f};
+ Shape shape = ShapeUtil::MakeShape(F32, {3});
+ auto data_handle = client_->TransferToServerInProcess(shape, input.data())
+ .ConsumeValueOrDie();
+
+ ComputationBuilder builder(client_, TestName());
+ auto param = builder.Parameter(0, ShapeUtil::MakeShape(F32, {3}), "param");
+ builder.Add(param, param);
+
+ ComputeAndCompareR1<float>(&builder, {2.0f, 4.0f, -84.0f},
+ {data_handle.get()}, error_spec_);
+}
+
+// TODO(b/28506710): This test case does not appear to exercise the InProcess
+// methods.
+TEST_F(InProcessServiceTest, GetShape) {
+ ComputationBuilder builder(client_, TestName());
+ builder.ConstantR1<int32>({1, 42, 5});
+ auto computation = builder.Build().ConsumeValueOrDie();
+
+ auto handle = client_->Execute(computation, {}).ConsumeValueOrDie();
+
+ Shape shape = client_->GetShape(*handle).ConsumeValueOrDie();
+ ASSERT_EQ(S32, shape.element_type());
+ ASSERT_EQ(1, ShapeUtil::Rank(shape));
+ ASSERT_EQ(3, shape.dimensions(0));
+}
+
+XLA_TEST_F(InProcessServiceTest, GetShapeOfClientSuppliedArrayRowMajor) {
+ std::vector<float> input{1.0f, 2.0f, 3.0f, 4.0f};
+ Shape shape = ShapeUtil::MakeShape(F32, {2, 2});
+ shape.clear_layout();
+ *shape.mutable_layout() = LayoutUtil::MakeLayout({1, 0});
+ auto handle = client_->TransferToServerInProcess(shape, input.data())
+ .ConsumeValueOrDie();
+
+ Shape shape_returned = client_->GetShape(*handle).ConsumeValueOrDie();
+ ASSERT_TRUE(ShapeUtil::Equal(shape, shape_returned));
+}
+
+XLA_TEST_F(InProcessServiceTest, GetShapeOfClientSuppliedArrayColMajor) {
+ std::vector<float> input{1.0f, 2.0f, 3.0f, 4.0f};
+ Shape shape = ShapeUtil::MakeShape(F32, {2, 2});
+ shape.clear_layout();
+ *shape.mutable_layout() = LayoutUtil::MakeLayout({0, 1});
+ auto handle = client_->TransferToServerInProcess(shape, input.data())
+ .ConsumeValueOrDie();
+
+ Shape shape_returned = client_->GetShape(*handle).ConsumeValueOrDie();
+ ASSERT_TRUE(ShapeUtil::Equal(shape, shape_returned));
+}
+
+TEST_F(InProcessServiceTest, TransferToServerNoLayout) {
+ std::vector<float> input{1.0f, 2.0f, -42.0f};
+ Shape shape = ShapeUtil::MakeShape(F32, {3});
+ shape.clear_layout();
+ auto transfer_status =
+ client_->TransferToServerInProcess(shape, input.data());
+ ASSERT_EQ(transfer_status.status().code(),
+ tensorflow::error::INVALID_ARGUMENT);
+}
+
+XLA_TEST_F(InProcessServiceTest, ExecuteRowMajor) {
+ auto handle =
+ ExecuteR2F32Constant({{1.0, 2.0}, {3.0, 4.0}}, /*minor_to_major=*/{1, 0});
+
+ std::vector<float> result(4, 0.0);
+ ASSERT_IS_OK(client_->TransferInProcess(*handle, result.data()));
+
+ EXPECT_MATCH(result, testing::VectorMatcher<float>({1.0, 2.0, 3.0, 4.0}));
+}
+
+XLA_TEST_F(InProcessServiceTest, ExecuteColumnMajor) {
+ auto handle =
+ ExecuteR2F32Constant({{1.0, 2.0}, {3.0, 4.0}}, /*minor_to_major=*/{0, 1});
+
+ std::vector<float> result(4, 0);
+ ASSERT_IS_OK(client_->TransferInProcess(*handle, result.data()));
+
+ EXPECT_MATCH(result, testing::VectorMatcher<float>({1.0, 3.0, 2.0, 4.0}));
+}
+
+XLA_TEST_F(InProcessServiceTest, ExecuteAndReuseDifferentLayouts) {
+ // Create arrays on the server which have different layouts. Verify the
+ // computation still produces the correct results.
+ auto handle_rowmaj =
+ ExecuteR2F32Constant({{1.0, 2.0}, {3.0, 4.0}}, /*minor_to_major=*/{1, 0});
+
+ auto handle_colmaj = ExecuteR2F32Constant({{10.0, 20.0}, {30.0, 40.0}},
+ /*minor_to_major=*/{0, 1});
+
+ ComputationBuilder builder(client_, TestName());
+ auto param0 =
+ builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2}), "param0");
+ auto param1 =
+ builder.Parameter(1, ShapeUtil::MakeShape(F32, {2, 2}), "param1");
+ builder.Add(param0, param1);
+
+ Array2D<float> expected({{11.0, 22.0}, {33.0, 44.0}});
+ ComputeAndCompareR2<float>(&builder, expected,
+ {handle_rowmaj.get(), handle_colmaj.get()},
+ error_spec_);
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/literal_test_util.cc b/tensorflow/compiler/xla/tests/literal_test_util.cc
new file mode 100644
index 0000000000..f7bbc0f38b
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/literal_test_util.cc
@@ -0,0 +1,566 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+
+#include <unistd.h>
+#include <cmath>
+#include <vector>
+
+#include "tensorflow/compiler/xla/index_util.h"
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/lib/core/casts.h"
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+
+/* static */ void LiteralTestUtil::AssertEqualShapes(const Shape& expected,
+ const Shape& actual) {
+ ASSERT_EQ(ShapeUtil::Rank(expected), ShapeUtil::Rank(actual));
+ ASSERT_EQ(expected.element_type(), actual.element_type())
+ << PrimitiveType_Name(expected.element_type()) << " vs "
+ << PrimitiveType_Name(actual.element_type());
+ ASSERT_EQ(expected.dimensions_size(), actual.dimensions_size());
+ for (int i = 0; i < expected.dimensions_size(); ++i) {
+ ASSERT_EQ(expected.dimensions(i), actual.dimensions(i))
+ << "mismatch in dimension #" << i
+ << " expected: " << ShapeUtil::HumanString(expected)
+ << " actual: " << ShapeUtil::HumanString(actual);
+ }
+ ASSERT_EQ(expected.tuple_shapes_size(), actual.tuple_shapes_size());
+ for (int i = 0; i < expected.tuple_shapes_size(); ++i) {
+ AssertEqualShapes(expected.tuple_shapes(i), actual.tuple_shapes(i));
+ }
+}
+
+/* static */ void LiteralTestUtil::AssertEqualShapesAndLayouts(
+ const Shape& expected, const Shape& actual) {
+ ASSERT_EQ(expected.ShortDebugString(), actual.ShortDebugString());
+}
+
+namespace {
+
+string Hostname() {
+ char hostname[1024];
+ gethostname(hostname, sizeof hostname);
+ hostname[sizeof hostname - 1] = 0;
+ return string(hostname);
+}
+
+// Helper function that compares two values of floating-point type FloatT for
+// bitwise equality by bit-casting them to UnsignedT; on miscompare, a
+// descriptive error message is included in the AssertionFailure.
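+// Note that bitwise equality distinguishes +0.0 from -0.0 and treats two NaNs
+// as equal exactly when their bit patterns match, unlike operator==.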
+template <typename FloatT, typename UnsignedT>
+testing::AssertionResult CompareFloatsBitwiseEqual(FloatT lhs, FloatT rhs) {
+ auto ulhs = tensorflow::bit_cast<UnsignedT>(lhs);
+ auto urhs = tensorflow::bit_cast<UnsignedT>(rhs);
+ if (ulhs != urhs) {
+ return testing::AssertionFailure() << tensorflow::strings::Printf(
+ "floating values are not bitwise-equal; and equality testing "
+ "was requested: %s=%g=%a vs %s=%g=%a",
+ tensorflow::strings::StrCat(tensorflow::strings::Hex(ulhs))
+ .c_str(),
+ lhs, lhs,
+ tensorflow::strings::StrCat(tensorflow::strings::Hex(urhs))
+ .c_str(),
+ rhs, rhs);
+ }
+ return testing::AssertionSuccess();
+}
+
+// Templated equality comparator; floating-point types are specialized below to
+// use the bitwise helper above. This unspecialized fallback simply uses the
+// default gunit-style comparison.
+template <typename NativeT>
+testing::AssertionResult CompareEqual(NativeT lhs, NativeT rhs) {
+ if (lhs == rhs) {
+ return testing::AssertionSuccess();
+ }
+ ::testing::Message msg;
+ msg << "Expected equality of these values:";
+ msg << "\n " << lhs;
+ msg << "\n " << rhs;
+
+ return testing::AssertionFailure() << msg;
+}
+
+// Specializations for floating types that do bitwise comparisons when equality
+// comparison is requested.
+template <>
+testing::AssertionResult CompareEqual<float>(float lhs, float rhs) {
+ return CompareFloatsBitwiseEqual<float, uint32>(lhs, rhs);
+}
+template <>
+testing::AssertionResult CompareEqual<double>(double lhs, double rhs) {
+ return CompareFloatsBitwiseEqual<double, uint64>(lhs, rhs);
+}
+
+// A recursive function which iterates through every index of the expected and
+// actual literals and compares their values elementwise. Returns true if all
+// elements are equal.
+template <typename NativeT>
+bool ExpectLiteralsEqual(const Literal& expected, const Literal& actual,
+ tensorflow::gtl::MutableArraySlice<int64> multi_index,
+ int64 dimension) {
+ if (dimension == expected.shape().dimensions_size()) {
+ NativeT expected_value = LiteralUtil::Get<NativeT>(expected, multi_index);
+ NativeT actual_value = LiteralUtil::Get<NativeT>(actual, multi_index);
+ testing::AssertionResult result =
+ CompareEqual<NativeT>(expected_value, actual_value);
+    return result;  // Relies on AssertionResult's implicit conversion to bool.
+ }
+
+ bool all_match = true;
+ for (int64 i = 0; i < expected.shape().dimensions(dimension); ++i) {
+ multi_index[dimension] = i;
+ all_match = all_match && ExpectLiteralsEqual<NativeT>(
+ expected, actual, multi_index, dimension + 1);
+ }
+ return all_match;
+}
+
+} // namespace
+
+/* static */ void LiteralTestUtil::ExpectEqual(const Literal& expected,
+ const Literal& actual) {
+ EXPECT_TRUE(Equal(expected, actual)) << "expected:\n"
+ << LiteralUtil::ToString(expected)
+ << "\n\tvs actual:\n"
+ << LiteralUtil::ToString(actual);
+}
+
+/* static */ void LiteralTestUtil::ExpectNotEqual(const Literal& expected,
+ const Literal& actual) {
+ EXPECT_FALSE(Equal(expected, actual));
+}
+
+/* static */ testing::AssertionResult LiteralTestUtil::Equal(
+ const Literal& expected, const Literal& actual) {
+ VLOG(1) << "expected: " << LiteralUtil::ToString(expected);
+ VLOG(1) << "actual: " << LiteralUtil::ToString(actual);
+
+ AssertEqualShapes(expected.shape(), actual.shape());
+ std::vector<int64> multi_index(expected.shape().dimensions_size(), 0);
+ bool match = false;
+ switch (expected.shape().element_type()) {
+ case PRED:
+ match = ExpectLiteralsEqual<bool>(expected, actual, &multi_index, 0);
+ break;
+ case U8:
+ match = ExpectLiteralsEqual<uint8>(expected, actual, &multi_index, 0);
+ break;
+ case S32:
+ match = ExpectLiteralsEqual<int32>(expected, actual, &multi_index, 0);
+ break;
+ case S64:
+ match = ExpectLiteralsEqual<int64>(expected, actual, &multi_index, 0);
+ break;
+ case U32:
+ match = ExpectLiteralsEqual<uint32>(expected, actual, &multi_index, 0);
+ break;
+ case U64:
+ match = ExpectLiteralsEqual<uint64>(expected, actual, &multi_index, 0);
+ break;
+ case F32:
+ match = ExpectLiteralsEqual<float>(expected, actual, &multi_index, 0);
+ break;
+ case F64:
+ match = ExpectLiteralsEqual<double>(expected, actual, &multi_index, 0);
+ break;
+ case TUPLE: {
+ bool tuple_match = true;
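+      // Compare every element so that all mismatching tuple elements are
+      // visited, but latch tuple_match to false on the first failure.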
+ for (int i = 0; i < actual.tuple_literals_size(); ++i) {
+ auto result =
+ Equal(expected.tuple_literals(i), actual.tuple_literals(i));
+ tuple_match = tuple_match ? !!result : false;
+ }
+ match = tuple_match;
+ break;
+ }
+ default:
+ LOG(FATAL)
+ << "Unsupported primitive type in LiteralTestUtil::ExpectEqual: "
+ << PrimitiveType_Name(expected.shape().element_type());
+ }
+ testing::AssertionResult result = testing::AssertionSuccess();
+ if (!match) {
+ result = testing::AssertionFailure()
+ << "expected: " << LiteralUtil::ToString(expected)
+ << "\nactual: " << LiteralUtil::ToString(actual);
+ VLOG(1) << result.message();
+ }
+ return result;
+}
+
+/* static */ void LiteralTestUtil::ExpectEqualTuple(const Literal& expected,
+ const Literal& actual) {
+ VLOG(1) << "expected: " << LiteralUtil::ToString(expected);
+ VLOG(1) << "actual: " << LiteralUtil::ToString(actual);
+
+ ASSERT_TRUE(ShapeUtil::IsTuple(expected.shape()));
+ ASSERT_TRUE(ShapeUtil::IsTuple(actual.shape()));
+ AssertEqualShapes(expected.shape(), actual.shape());
+  for (int i = 0; i < expected.tuple_literals_size(); ++i) {
+ const auto& expected_element = expected.tuple_literals(i);
+ const auto& actual_element = actual.tuple_literals(i);
+ if (ShapeUtil::IsTuple(expected_element.shape())) {
+ ExpectEqualTuple(expected_element, actual_element);
+ } else {
+ ExpectEqual(expected_element, actual_element);
+ }
+ }
+}
+
+namespace {
+
+// Helper class for comparing floating-point literals within an error bound.
+class NearComparator {
+ public:
+ explicit NearComparator(ErrorSpec error) : error_(error) {}
+
+ // Compares the two literals elementwise. EXPECTs each pair of elements to be
+ // within the error bound. Emits useful log messages and dumps literals to
+  // temporary files on failure. Returns true if the literals match.
+ bool ExpectNear(const Literal& expected, const Literal& actual) {
+ VLOG(1) << "expected: " << LiteralUtil::ToString(expected);
+ VLOG(1) << "actual: " << LiteralUtil::ToString(actual);
+
+ LiteralTestUtil::AssertEqualShapes(expected.shape(), actual.shape());
+
+ // Set up members used during the comparison.
+ num_miscompares_ = 0;
+ abs_diff_sum_ = 0.0;
+ abs_expected_sum_ = 0.0;
+ abs_diff_miscompare_sum_ = 0.0;
+ abs_expected_miscompare_sum_ = 0.0;
+ max_rel_err_ = 0.0;
+ max_abs_err_ = 0.0;
+ *miscompares_.mutable_shape() =
+ ShapeUtil::ChangeElementType(actual.shape(), PRED);
+ miscompares_.mutable_preds()->Resize(
+ ShapeUtil::ElementsIn(miscompares_.shape()), false);
+ multi_index_.resize(expected.shape().dimensions_size(), 0);
+
+ switch (expected.shape().element_type()) {
+ case F32:
+ ExpectLiteralsNear<float>(expected, actual, 0);
+ break;
+ case F64:
+ ExpectLiteralsNear<double>(expected, actual, 0);
+ break;
+ default:
+ LOG(FATAL) << "Unsupported primitive type in near comparator: "
+ << PrimitiveType_Name(expected.shape().element_type())
+ << ". Must be floating-point type.";
+ }
+
+ if (num_miscompares_ > 0) {
+ if (!VLOG_IS_ON(1)) {
+ LOG(INFO) << "expected: " << ShapeUtil::HumanString(expected.shape())
+ << " " << LiteralUtil::ToString(expected);
+ LOG(INFO) << "actual: " << ShapeUtil::HumanString(actual.shape())
+ << " " << LiteralUtil::ToString(actual);
+ }
+ EXPECT_TRUE(num_miscompares_ == 0)
+ << "\nmax relative mismatch at index "
+ << LiteralTestUtil::MultiIndexAsString(max_rel_multi_index_)
+ << "\nmaximum relative error " << max_rel_err_
+ << "\nmax absolute mismatch at index "
+ << LiteralTestUtil::MultiIndexAsString(max_abs_multi_index_)
+ << "\nmaximum absolute error " << max_abs_err_
+ << "\nfirst mismatch at index "
+ << LiteralTestUtil::MultiIndexAsString(first_multi_index_)
+ << "\nlast mismatch at index "
+ << LiteralTestUtil::MultiIndexAsString(last_multi_index_)
+ << "\ntotal absolute error " << abs_diff_sum_
+ << "\ntotal absolute error of miscompares "
+ << abs_diff_miscompare_sum_ << "\ntotal relative error "
+ << (abs_diff_sum_ / abs_expected_sum_)
+ << "\ntotal relative error of miscompares "
+ << (abs_diff_miscompare_sum_ / abs_expected_miscompare_sum_)
+ << "\nfailure count " << num_miscompares_;
+
+ WriteLiteralToTempFile(expected, "expected");
+ WriteLiteralToTempFile(actual, "actual");
+ WriteLiteralToTempFile(miscompares_, "miscompares");
+ }
+ return num_miscompares_ == 0;
+ }
+
+ private:
+ // EXPECTs that the two given scalar values are within the error bound. Keeps
+  // track of how many mismatches have occurred to keep the size of the output
+ // manageable.
+ template <typename NativeT>
+ bool ExpectValuesNear(NativeT expected, NativeT actual) {
+ if (expected == actual) {
+ return true;
+ }
+
+ float abs_diff = std::abs(actual - expected);
+ float rel_err = abs_diff / std::abs(expected);
+ abs_diff_sum_ += abs_diff;
+ abs_expected_sum_ += std::abs(expected);
+ if (rel_err > max_rel_err_) {
+ max_rel_err_ = rel_err;
+ max_rel_multi_index_ = multi_index_;
+ }
+ if (abs_diff > max_abs_err_) {
+ max_abs_err_ = abs_diff;
+ max_abs_multi_index_ = multi_index_;
+ }
+ VLOG(10) << tensorflow::strings::Printf(
+ "index %s abs_diff %f rel_err %f",
+ LiteralTestUtil::MultiIndexAsString(multi_index_).c_str(), abs_diff,
+ rel_err);
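+    // A mismatch requires exceeding both tolerances (or disagreeing on NaN);
+    // a value within either the absolute or the relative bound is accepted.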
+ bool nan_mismatch = std::isnan(actual) != std::isnan(expected);
+ bool mismatch =
+ (nan_mismatch || (abs_diff >= error_.abs && rel_err >= error_.rel));
+ if (mismatch) {
+ abs_diff_miscompare_sum_ += abs_diff;
+ abs_expected_miscompare_sum_ += std::abs(expected);
+ const int64 kMaxFailures = 2;
+ if (num_miscompares_ < kMaxFailures) {
+ EXPECT_NEAR(expected, actual, error_.abs)
+ << "mismatch at index "
+ << LiteralTestUtil::MultiIndexAsString(multi_index_) << " abs diff "
+ << abs_diff << " rel err " << rel_err << " failure #"
+ << num_miscompares_;
+ } else if (num_miscompares_ == kMaxFailures) {
+ LOG(ERROR)
+ << "reached max 'loud' failure count; silently proceeding...";
+ }
+ if (num_miscompares_ == 0) {
+ first_multi_index_ = multi_index_;
+ }
+ num_miscompares_++;
+ last_multi_index_ = multi_index_;
+ }
+ return !mismatch;
+ }
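+
+  // For intuition, a hypothetical walk through the predicate above with
+  // error_ = ErrorSpec(0.001, 0.01):
+  //
+  //   expected = 100.0f, actual = 100.5f
+  //   abs_diff = 0.5    -> exceeds error_.abs
+  //   rel_err  = 0.005  -> within error_.rel
+  //
+  // No mismatch is recorded: an element miscompares only when it exceeds
+  // both the absolute and the relative bound (or when exactly one of the two
+  // values is NaN).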
+
+ // Recursive function which compares the two given literals elementwise.
+ template <typename NativeT>
+ void ExpectLiteralsNear(const Literal& expected, const Literal& actual,
+ int64 dimension) {
+ if (dimension == expected.shape().dimensions_size()) {
+ bool near =
+ ExpectValuesNear(LiteralUtil::Get<NativeT>(expected, multi_index_),
+ LiteralUtil::Get<NativeT>(actual, multi_index_));
+ LiteralUtil::Set<bool>(&miscompares_, multi_index_, !near);
+ } else {
+ for (int64 i = 0; i < expected.shape().dimensions(dimension); ++i) {
+ multi_index_[dimension] = i;
+ ExpectLiteralsNear<NativeT>(expected, actual, dimension + 1);
+ }
+ }
+ }
+
+ // Writes the given literal to a file in the test temporary directory.
+ void WriteLiteralToTempFile(const Literal& literal, const string& name) {
+ int64 now_usec = tensorflow::Env::Default()->NowMicros();
+ string filename = tensorflow::io::JoinPath(
+ tensorflow::testing::TmpDir(),
+ tensorflow::strings::Printf("tempfile-%s-%llx-%s", Hostname().c_str(),
+ now_usec, name.c_str()));
+ TF_CHECK_OK(tensorflow::WriteBinaryProto(tensorflow::Env::Default(),
+ filename, literal));
+ LOG(ERROR) << "wrote to " << name << " file: " << filename;
+ }
+
+ ErrorSpec error_;
+
+ // Number of element miscomparisons encountered so far.
+ int64 num_miscompares_;
+
+  // A literal recording which elements did not match between the expected and
+  // actual literals. miscompares_ contains PREDs and has the same shape as
+  // the literals being compared.
+ Literal miscompares_;
+
+ // A multidimensional index used when performing the recursive comparison.
+ std::vector<int64> multi_index_;
+
+  // Aggregated statistics on the inputs.
+ double abs_diff_sum_;
+ double abs_expected_sum_;
+ double abs_diff_miscompare_sum_;
+ double abs_expected_miscompare_sum_;
+ float max_rel_err_;
+ float max_abs_err_;
+ std::vector<int64> first_multi_index_;
+ std::vector<int64> last_multi_index_;
+ std::vector<int64> max_rel_multi_index_;
+ std::vector<int64> max_abs_multi_index_;
+};
+
+} // namespace
+
+/* static */ testing::AssertionResult LiteralTestUtil::Near(
+ const Literal& expected, const Literal& actual, const ErrorSpec& error) {
+ NearComparator comparator(error);
+ return comparator.ExpectNear(expected, actual)
+ ? testing::AssertionSuccess()
+ : testing::AssertionFailure() << "values were not near";
+}
+
+/* static */ void LiteralTestUtil::ExpectNear(const Literal& expected,
+ const Literal& actual,
+ const ErrorSpec& error) {
+ EXPECT_TRUE(Near(expected, actual, error));
+}
+
+/* static */ testing::AssertionResult LiteralTestUtil::NearTuple(
+ const Literal& expected, const Literal& actual, const ErrorSpec& error) {
+ VLOG(1) << "expected: " << LiteralUtil::ToString(expected);
+ VLOG(1) << "actual: " << LiteralUtil::ToString(actual);
+
+ if (!ShapeUtil::IsTuple(expected.shape()) ||
+ !ShapeUtil::IsTuple(actual.shape())) {
+ return testing::AssertionFailure()
+ << "tuples expected expected shape = "
+ << expected.shape().ShortDebugString()
+ << " actual shape = " << actual.shape().ShortDebugString();
+ }
+ AssertEqualShapes(expected.shape(), actual.shape());
+ for (uint64 i = 0; i < expected.tuple_literals_size(); ++i) {
+ const auto& expected_element = expected.tuple_literals(i);
+ const auto& actual_element = actual.tuple_literals(i);
+ if (ShapeUtil::IsTuple(expected_element.shape())) {
+ auto ret = NearTuple(expected_element, actual_element, error);
+ if (!ret) {
+ return ret;
+ }
+ } else if (ShapeUtil::ElementIsFloating(expected_element.shape())) {
+ auto ret = Near(expected_element, actual_element, error);
+ if (!ret) {
+ return ret;
+ }
+ } else {
+ auto ret = Equal(expected_element, actual_element);
+ if (!ret) {
+ return ret;
+ }
+ }
+ }
+
+ return testing::AssertionSuccess();
+}
+
+/* static */ void LiteralTestUtil::ExpectNearTuple(const Literal& expected,
+ const Literal& actual,
+ const ErrorSpec& error) {
+ EXPECT_TRUE(NearTuple(expected, actual, error));
+}
+
+/* static */ string LiteralTestUtil::MultiIndexAsString(
+ tensorflow::gtl::ArraySlice<int64> multi_index) {
+ return tensorflow::strings::StrCat(
+ "{", tensorflow::str_util::Join(multi_index, ","), "}");
+}
+
+/* static */ std::unique_ptr<Literal> LiteralTestUtil::Reshape(
+ tensorflow::gtl::ArraySlice<int64> new_dimensions,
+ tensorflow::gtl::ArraySlice<int64> minor_to_major, const Literal& literal) {
+ int64 new_num_elements = 1;
+ for (int64 i = 0; i < new_dimensions.size(); ++i) {
+ new_num_elements *= new_dimensions[i];
+ }
+ CHECK_EQ(ShapeUtil::ElementsIn(literal.shape()), new_num_elements);
+
+ auto new_literal = MakeUnique<Literal>();
+ *new_literal->mutable_shape() =
+ ShapeUtil::MakeShape(literal.shape().element_type(), new_dimensions);
+
+  // Create a new shape with the given minor-to-major layout. This shape is
+  // used solely for converting linear addresses to multidimensional addresses
+  // when writing elements to the new literal.
+ Shape shape_with_layout = new_literal->shape();
+ *shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major);
+
+ // Allocate space in the new literal.
+ LiteralUtil::Reserve(ShapeUtil::ElementsIn(literal.shape()),
+ new_literal.get());
+
+ // Copy data into new literal, element-by-element.
+ for (int64 i = 0; i < ShapeUtil::ElementsIn(literal.shape()); ++i) {
+ std::vector<int64> from_multi_index =
+ IndexUtil::LinearIndexToMultidimensionalIndex(literal.shape(), i);
+ std::vector<int64> to_multi_index =
+ IndexUtil::LinearIndexToMultidimensionalIndex(shape_with_layout, i);
+ switch (literal.shape().element_type()) {
+ case PRED:
+ LiteralUtil::Set<bool>(
+ new_literal.get(), to_multi_index,
+ LiteralUtil::Get<bool>(literal, from_multi_index));
+ break;
+ case U8:
+ LiteralUtil::Set<uint8>(
+ new_literal.get(), to_multi_index,
+ LiteralUtil::Get<uint8>(literal, from_multi_index));
+ break;
+ case U32:
+ LiteralUtil::Set<uint32>(
+ new_literal.get(), to_multi_index,
+ LiteralUtil::Get<uint32>(literal, from_multi_index));
+ break;
+ case S32:
+ LiteralUtil::Set<int32>(
+ new_literal.get(), to_multi_index,
+ LiteralUtil::Get<int32>(literal, from_multi_index));
+ break;
+ case U64:
+ LiteralUtil::Set<uint64>(
+ new_literal.get(), to_multi_index,
+ LiteralUtil::Get<uint64>(literal, from_multi_index));
+ break;
+ case S64:
+ LiteralUtil::Set<int64>(
+ new_literal.get(), to_multi_index,
+ LiteralUtil::Get<int64>(literal, from_multi_index));
+ break;
+ case F32:
+ LiteralUtil::Set<float>(
+ new_literal.get(), to_multi_index,
+ LiteralUtil::Get<float>(literal, from_multi_index));
+ break;
+ case F64:
+ LiteralUtil::Set<double>(
+ new_literal.get(), to_multi_index,
+ LiteralUtil::Get<double>(literal, from_multi_index));
+ break;
+ default:
+ LOG(FATAL) << "Unhandled primitive element type: "
+ << PrimitiveType_Name(literal.shape().element_type());
+ }
+ }
+
+ return new_literal;
+}
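+
+// A hypothetical usage sketch of Reshape (the literal name is illustrative):
+//
+//   // Flatten a 2x2 literal into an R1 literal of four elements, reading the
+//   // flat buffer in the order described by minor_to_major {0}.
+//   std::unique_ptr<Literal> flat = LiteralTestUtil::Reshape(
+//       /*new_dimensions=*/{4}, /*minor_to_major=*/{0}, *some_2x2_literal);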
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/literal_test_util.h b/tensorflow/compiler/xla/tests/literal_test_util.h
new file mode 100644
index 0000000000..85656a53e4
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/literal_test_util.h
@@ -0,0 +1,274 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_TESTS_LITERAL_TEST_UTIL_H_
+#define TENSORFLOW_COMPILER_XLA_TESTS_LITERAL_TEST_UTIL_H_
+
+#include <initializer_list>
+#include <memory>
+#include <string>
+
+#include "tensorflow/compiler/xla/array2d.h"
+#include "tensorflow/compiler/xla/array3d.h"
+#include "tensorflow/compiler/xla/array4d.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+
+// Structure describing permissible absolute and relative error bounds.
+struct ErrorSpec {
+ explicit ErrorSpec(float aabs, float arel = 0) : abs(aabs), rel(arel) {}
+
+ float abs; // Absolute error bound.
+ float rel; // Relative error bound.
+};
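+
+// For example (hypothetical values), ErrorSpec(0.001, 0.01) accepts an element
+// whose absolute error is below 0.001 or whose relative error is below 0.01;
+// an element is reported as a miscompare only when it violates both bounds.
+//
+//   LiteralTestUtil::ExpectNear(*LiteralUtil::CreateR0<float>(2.0f),
+//                               *LiteralUtil::CreateR0<float>(2.0005f),
+//                               ErrorSpec(0.001f, 0.01f));  // abs diff 0.0005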
+
+// Utility class for making expectations/assertions related to XLA literals.
+class LiteralTestUtil {
+ public:
+ // Asserts that the given shapes have the same rank, dimension sizes, and
+ // primitive types.
+ static void AssertEqualShapes(const Shape& expected, const Shape& actual);
+
+ // Asserts that the provided shapes are equal as defined in AssertEqualShapes
+ // and that they have the same layout.
+ static void AssertEqualShapesAndLayouts(const Shape& expected,
+ const Shape& actual);
+
+ // Asserts that the expected and actual literals are (bitwise) equal for all
+  // elements in the literal. Also asserts that the rank, dimension sizes, and
+ // primitive type are equal.
+ static testing::AssertionResult Equal(
+ const Literal& expected, const Literal& actual) TF_MUST_USE_RESULT;
+
+ // Expects that expected and actual are Equal.
+ static void ExpectEqual(const Literal& expected, const Literal& actual);
+
+ // Expects that expected and actual are Not Equal.
+ static void ExpectNotEqual(const Literal& expected, const Literal& actual);
+
+  // Asserts that the given literal is (bitwise) equal to the given expected
+  // values.
+ template <typename NativeT>
+ static void ExpectR0Equal(NativeT expected, const Literal& actual);
+ template <typename NativeT>
+ static void ExpectR1Equal(tensorflow::gtl::ArraySlice<NativeT> expected,
+ const Literal& actual);
+ template <typename NativeT>
+ static void ExpectR2Equal(
+ std::initializer_list<std::initializer_list<NativeT>> expected,
+ const Literal& actual);
+ template <typename NativeT>
+ static void ExpectR3Equal(
+ std::initializer_list<
+ std::initializer_list<std::initializer_list<NativeT>>>
+ expected,
+ const Literal& actual);
+
+  // Asserts that the given literal is (bitwise) equal to the given array.
+ template <typename NativeT>
+ static void ExpectR2EqualArray2D(const Array2D<NativeT>& expected,
+ const Literal& actual);
+ template <typename NativeT>
+ static void ExpectR3EqualArray3D(const Array3D<NativeT>& expected,
+ const Literal& actual);
+ template <typename NativeT>
+ static void ExpectR4EqualArray4D(const Array4D<NativeT>& expected,
+ const Literal& actual);
+
+ // Expects that the values of the elements in the expected and actual tuples
+ // are equal. Tuples are matched recursively.
+ static void ExpectEqualTuple(const Literal& expected, const Literal& actual);
+
+ // Asserts that the expected and actual literals are within the given error
+  // bound for all elements. Also asserts that the rank and dimension sizes
+  // are equivalent. Only supported for floating-point values.
+ static testing::AssertionResult Near(
+ const Literal& expected, const Literal& actual,
+ const ErrorSpec& error) TF_MUST_USE_RESULT;
+
+ // Expects expected and actual to be Near with the given error.
+ static void ExpectNear(const Literal& expected, const Literal& actual,
+ const ErrorSpec& error);
+
+  // Asserts that the given literal is within the given error bound of the
+  // given expected values. Only supported for floating-point values.
+ template <typename NativeT>
+ static void ExpectR0Near(NativeT expected, const Literal& actual,
+ const ErrorSpec& error);
+ template <typename NativeT>
+ static void ExpectR1Near(tensorflow::gtl::ArraySlice<NativeT> expected,
+ const Literal& actual, const ErrorSpec& error);
+ template <typename NativeT>
+ static void ExpectR2Near(
+ std::initializer_list<std::initializer_list<NativeT>> expected,
+ const Literal& actual, const ErrorSpec& error);
+ template <typename NativeT>
+ static void ExpectR3Near(
+ std::initializer_list<
+ std::initializer_list<std::initializer_list<NativeT>>>
+ expected,
+ const Literal& actual, const ErrorSpec& error);
+
+  // Asserts that the given literal is within the given error bound of the
+  // given array. Only supported for floating-point values.
+ template <typename NativeT>
+ static void ExpectR2NearArray2D(const Array2D<NativeT>& expected,
+ const Literal& actual,
+ const ErrorSpec& error);
+ template <typename NativeT>
+ static void ExpectR3NearArray3D(const Array3D<NativeT>& expected,
+ const Literal& actual,
+ const ErrorSpec& error);
+ template <typename NativeT>
+ static void ExpectR4NearArray4D(const Array4D<NativeT>& expected,
+ const Literal& actual,
+ const ErrorSpec& error);
+
+ // Returns whether the values of the elements in the expected and actual
+ // tuples are within the given error bound. Tuples are matched recursively.
+ // If the elements of the tuple are not floating-point types, the error spec
+ // is ignored and exact equality is checked.
+ static testing::AssertionResult NearTuple(
+ const Literal& expected, const Literal& actual,
+ const ErrorSpec& error) TF_MUST_USE_RESULT;
+
+ // Expects that the expected and actual values are near.
+ static void ExpectNearTuple(const Literal& expected, const Literal& actual,
+ const ErrorSpec& error);
+
+  // Returns a multi-dimensional index as a string. For example, "{7,8}" is
+  // returned for a 2-dimensional index whose dimension 0 index is 7 and whose
+  // dimension 1 index is 8.
+ static string MultiIndexAsString(
+ tensorflow::gtl::ArraySlice<int64> multi_index);
+
+ // Creates a literal with a new shape with the given new dimensions using the
+ // data in the given input literal. For reshaping purposes the (flat) data
+ // buffer of the input literal is assumed to have the given minor_to_major
+ // layout order.
+ static std::unique_ptr<Literal> Reshape(
+ tensorflow::gtl::ArraySlice<int64> new_dimensions,
+ tensorflow::gtl::ArraySlice<int64> minor_to_major,
+ const Literal& literal);
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(LiteralTestUtil);
+};
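+
+// A minimal usage sketch (the literals below are illustrative, not part of
+// this header):
+//
+//   auto expected = LiteralUtil::CreateR1<float>({1.0f, 2.0f});
+//   auto actual   = LiteralUtil::CreateR1<float>({1.0001f, 2.0f});
+//   LiteralTestUtil::ExpectEqual(*expected, *expected);
+//   LiteralTestUtil::ExpectNear(*expected, *actual, ErrorSpec(0.001f));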
+
+template <typename NativeT>
+/* static */ void LiteralTestUtil::ExpectR0Equal(NativeT expected,
+ const Literal& actual) {
+ ExpectEqual(*LiteralUtil::CreateR0<NativeT>(expected), actual);
+}
+
+template <typename NativeT>
+/* static */ void LiteralTestUtil::ExpectR1Equal(
+ tensorflow::gtl::ArraySlice<NativeT> expected, const Literal& actual) {
+ ExpectEqual(*LiteralUtil::CreateR1<NativeT>(expected), actual);
+}
+
+template <typename NativeT>
+/* static */ void LiteralTestUtil::ExpectR2Equal(
+ std::initializer_list<std::initializer_list<NativeT>> expected,
+ const Literal& actual) {
+ ExpectEqual(*LiteralUtil::CreateR2<NativeT>(expected), actual);
+}
+
+template <typename NativeT>
+/* static */ void LiteralTestUtil::ExpectR3Equal(
+ std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
+ expected,
+ const Literal& actual) {
+ ExpectEqual(*LiteralUtil::CreateR3<NativeT>(expected), actual);
+}
+
+template <typename NativeT>
+/* static */ void LiteralTestUtil::ExpectR2EqualArray2D(
+ const Array2D<NativeT>& expected, const Literal& actual) {
+ ExpectEqual(*LiteralUtil::CreateR2FromArray2D(expected), actual);
+}
+
+template <typename NativeT>
+/* static */ void LiteralTestUtil::ExpectR3EqualArray3D(
+ const Array3D<NativeT>& expected, const Literal& actual) {
+ ExpectEqual(*LiteralUtil::CreateR3FromArray3D(expected), actual);
+}
+
+template <typename NativeT>
+/* static */ void LiteralTestUtil::ExpectR4EqualArray4D(
+ const Array4D<NativeT>& expected, const Literal& actual) {
+ ExpectEqual(*LiteralUtil::CreateR4FromArray4D(expected), actual);
+}
+
+template <typename NativeT>
+/* static */ void LiteralTestUtil::ExpectR0Near(NativeT expected,
+ const Literal& actual,
+ const ErrorSpec& error) {
+ ExpectNear(*LiteralUtil::CreateR0<NativeT>(expected), actual, error);
+}
+
+template <typename NativeT>
+/* static */ void LiteralTestUtil::ExpectR1Near(
+ tensorflow::gtl::ArraySlice<NativeT> expected, const Literal& actual,
+ const ErrorSpec& error) {
+ ExpectNear(*LiteralUtil::CreateR1<NativeT>(expected), actual, error);
+}
+
+template <typename NativeT>
+/* static */ void LiteralTestUtil::ExpectR2Near(
+ std::initializer_list<std::initializer_list<NativeT>> expected,
+ const Literal& actual, const ErrorSpec& error) {
+ ExpectNear(*LiteralUtil::CreateR2<NativeT>(expected), actual, error);
+}
+
+template <typename NativeT>
+/* static */ void LiteralTestUtil::ExpectR3Near(
+ std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
+ expected,
+ const Literal& actual, const ErrorSpec& error) {
+ ExpectNear(*LiteralUtil::CreateR3<NativeT>(expected), actual, error);
+}
+
+template <typename NativeT>
+/* static */ void LiteralTestUtil::ExpectR2NearArray2D(
+ const Array2D<NativeT>& expected, const Literal& actual,
+ const ErrorSpec& error) {
+ ExpectNear(*LiteralUtil::CreateR2FromArray2D(expected), actual, error);
+}
+
+template <typename NativeT>
+/* static */ void LiteralTestUtil::ExpectR3NearArray3D(
+ const Array3D<NativeT>& expected, const Literal& actual,
+ const ErrorSpec& error) {
+ ExpectNear(*LiteralUtil::CreateR3FromArray3D(expected), actual, error);
+}
+
+template <typename NativeT>
+/* static */ void LiteralTestUtil::ExpectR4NearArray4D(
+ const Array4D<NativeT>& expected, const Literal& actual,
+ const ErrorSpec& error) {
+ ExpectNear(*LiteralUtil::CreateR4FromArray4D(expected), actual, error);
+}
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_TESTS_LITERAL_TEST_UTIL_H_
diff --git a/tensorflow/compiler/xla/tests/literal_test_util_test.cc b/tensorflow/compiler/xla/tests/literal_test_util_test.cc
new file mode 100644
index 0000000000..fdec11c0e9
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/literal_test_util_test.cc
@@ -0,0 +1,102 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Tests that our utility functions for dealing with literals are correctly
+// implemented.
+
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+
+#include <vector>
+
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace {
+
+TEST(LiteralTestUtilTest, ComparesEqualTuplesEqual) {
+ std::unique_ptr<Literal> literal = LiteralUtil::MakeTuple({
+ LiteralUtil::CreateR0<int32>(42).get(),
+ LiteralUtil::CreateR0<int32>(64).get(),
+ });
+ LiteralTestUtil::ExpectEqual(*literal, *literal);
+}
+
+TEST(LiteralTestUtilTest, ComparesUnequalTuplesUnequal) {
+  // Implementation note: we have to use a death test here, because an
+  // assertion failure cannot be un-failed. A CHECK failure aborts the
+  // process, so we can make a death assertion instead.
+ auto unequal_things_are_equal = [] {
+ std::unique_ptr<Literal> lhs = LiteralUtil::MakeTuple({
+ LiteralUtil::CreateR0<int32>(42).get(),
+ LiteralUtil::CreateR0<int32>(64).get(),
+ });
+ std::unique_ptr<Literal> rhs = LiteralUtil::MakeTuple({
+ LiteralUtil::CreateR0<int32>(64).get(),
+ LiteralUtil::CreateR0<int32>(42).get(),
+ });
+ CHECK(LiteralTestUtil::Equal(*lhs, *rhs)) << "LHS and RHS are unequal";
+ };
+ ASSERT_DEATH(unequal_things_are_equal(), "LHS and RHS are unequal");
+}
+
+TEST(LiteralTestUtilTest, ExpectNearFailurePlacesResultsInTemporaryDirectory) {
+ auto dummy_lambda = [] {
+ auto two = LiteralUtil::CreateR0<float>(2);
+ auto four = LiteralUtil::CreateR0<float>(4);
+ ErrorSpec error(0.001);
+ CHECK(LiteralTestUtil::Near(*two, *four, error)) << "two is not near four";
+ };
+
+ tensorflow::Env* env = tensorflow::Env::Default();
+ string pattern =
+ tensorflow::io::JoinPath(tensorflow::testing::TmpDir(), "/tempfile-*");
+ std::vector<string> files;
+ TF_CHECK_OK(env->GetMatchingPaths(pattern, &files));
+ for (const auto& f : files) {
+ TF_CHECK_OK(env->DeleteFile(f)) << f;
+ }
+
+ ASSERT_DEATH(dummy_lambda(), "two is not near four");
+
+ // Now check we wrote temporary files to the temporary directory that we can
+ // read.
+ std::vector<string> results;
+ TF_CHECK_OK(env->GetMatchingPaths(pattern, &results));
+
+ LOG(INFO) << "results: [" << tensorflow::str_util::Join(results, ", ") << "]";
+ EXPECT_EQ(3, results.size());
+ for (const string& result : results) {
+ Literal literal;
+ TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), result,
+ &literal));
+ if (result.find("expected") != string::npos) {
+ EXPECT_EQ("2", LiteralUtil::ToString(literal));
+ } else if (result.find("actual") != string::npos) {
+ EXPECT_EQ("4", LiteralUtil::ToString(literal));
+ } else if (result.find("miscompares") != string::npos) {
+ EXPECT_EQ("true", LiteralUtil::ToString(literal));
+ } else {
+ FAIL() << "unknown file in temporary directory: " << result;
+ }
+ }
+}
+
+} // namespace
+} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/local_client_aot_test.cc b/tensorflow/compiler/xla/tests/local_client_aot_test.cc
new file mode 100644
index 0000000000..591fff338c
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/local_client_aot_test.cc
@@ -0,0 +1,55 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/executable_run_options.h"
+#include "tensorflow/core/platform/dynamic_annotations.h"
+#include "tensorflow/core/platform/test.h"
+
+class LocalClientAotTest : public ::testing::Test {};
+
+// This is a compiled XLA computation which calls SumStructElements, and then
+// doubles the result.
+extern "C" void SumAndDouble(float* out, xla::ExecutableRunOptions* options,
+ void** parameters, void** temporary_buffers);
+
+// A struct with some arbitrary bit-fields, used to test the OPAQUE type.
+struct OpaqueData {
+ int field1 : 15;
+ int field2 : 14;
+ int field3 : 3;
+};
+
+// This is the implementation of a custom op which will be called by
+// SumAndDouble.
+extern "C" void SumStructElements(float* out, void** parameters) {
+ TF_ANNOTATE_MEMORY_IS_INITIALIZED(parameters, sizeof(OpaqueData*));
+ const auto* opaque_data = static_cast<OpaqueData*>(parameters[0]);
+ *out = opaque_data->field1 + opaque_data->field2 + opaque_data->field3;
+}
+
+TEST_F(LocalClientAotTest, Constant) {
+ xla::ExecutableRunOptions run_options;
+ OpaqueData opaque_data{100, 20, 3};
+ void* parameters[] = {&opaque_data};
+ float out = 0;
+ float tmp = 0;
+ void* temporary_buffers[] = {&out, &tmp, nullptr};
+ SumAndDouble(&out, &run_options, parameters, temporary_buffers);
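+  // SumStructElements returns 100 + 20 + 3 = 123, which SumAndDouble doubles.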
+ EXPECT_EQ(out, 246.0f);
+
+ opaque_data = {1, 2, 3};
+ SumAndDouble(&out, &run_options, parameters, temporary_buffers);
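+  // 2 * (1 + 2 + 3) = 12.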
+ EXPECT_EQ(out, 12.0f);
+}
diff --git a/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc b/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc
new file mode 100644
index 0000000000..50e5dec0f6
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc
@@ -0,0 +1,111 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This program ahead-of-time compiles an XLA computation that doubles the
+// result of the SumStructElements custom call, and writes the resulting
+// object file to stdout.
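+//
+// Usage: local_client_aot_test_helper TARGET_CPU > object_file.o
+// where TARGET_CPU is one of: k8, darwin, arm, ppc, or local.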
+
+#include <iostream>
+#include <vector>
+
+#include "external/llvm/include/llvm/ADT/Triple.h"
+#include "tensorflow/compiler/xla/client/client_library.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/platform/init_main.h"
+#include "tensorflow/core/platform/logging.h"
+
+using xla::string;
+
+xla::Computation Doubler(xla::Client* client) {
+ xla::ComputationBuilder builder(client, "doubler");
+ auto r0f32 = xla::ShapeUtil::MakeShape(xla::F32, {});
+ auto x = builder.Parameter(0, r0f32, "x");
+ builder.Mul(x, builder.ConstantR0<float>(2.0));
+ return std::move(builder.Build().ValueOrDie());
+}
+
+int main(int argc, char** argv) {
+ tensorflow::port::InitMain(argv[0], &argc, &argv);
+
+ auto client = xla::ClientLibrary::LocalClientOrDie();
+
+ xla::ComputationBuilder builder(client, "aot_test_helper");
+ auto opaque_shape = xla::ShapeUtil::MakeOpaqueShape();
+ auto opaque_param = builder.Parameter(0, opaque_shape, "x");
+ auto r0f32 = xla::ShapeUtil::MakeShape(xla::F32, {});
+ auto sum = builder.CustomCall("SumStructElements", {opaque_param}, r0f32);
+ builder.Call(Doubler(client), {sum});
+
+ if (argc != 2) {
+ LOG(FATAL) << "local_client_aot_test_helper TARGET_CPU";
+ }
+
+ string triple_string;
+ string target_cpu = argv[1];
+ if (target_cpu == "k8") {
+ triple_string = "x86_64-none-linux-gnu";
+ } else if (target_cpu == "darwin") {
+ triple_string = "x86_64-apple-macosx";
+ } else if (target_cpu == "arm") {
+ triple_string = "aarch64-none-linux-gnu";
+ } else if (target_cpu == "ppc") {
+ triple_string = "powerpc64le-unknown-linux-gnu";
+ } else if (target_cpu == "local") {
+ triple_string = xla::llvm_ir::AsString(llvm::sys::getDefaultTargetTriple());
+ } else {
+ LOG(FATAL) << "unsupported TARGET_CPU: " << target_cpu;
+ }
+
+ llvm::Triple triple(xla::llvm_ir::AsStringRef(triple_string));
+
+ xla::cpu::CpuAotCompilationOptions options(
+ triple_string,
+ /*cpu_name=*/"", /*features=*/"", "SumAndDouble",
+ xla::cpu::CpuAotCompilationOptions::RelocationModel::Static);
+ auto result = xla::unique_ptr_static_cast<xla::cpu::CpuAotCompilationResult>(
+ client
+ ->CompileAheadOfTime(builder.Build().ValueOrDie(),
+ /*argument_layouts=*/{&opaque_shape}, r0f32,
+ options)
+ .ConsumeValueOrDie());
+  // We should have three buffer entries: a float-sized result buffer, a
+  // float-sized temporary buffer, and one entry of size -1 for which no
+  // allocation is required. It's lame to hard-code this, but we need
+  // local_client_aot_test.cc to be able to easily invoke the function.
+ CHECK_EQ(result->result_buffer_index(), 0);
+ CHECK_EQ(result->buffer_sizes().size(), 3);
+ CHECK_EQ(result->buffer_sizes()[0], sizeof(float)); // result buffer
+ CHECK_EQ(result->buffer_sizes()[1], sizeof(float)); // temp buffer
+ CHECK_EQ(result->buffer_sizes()[2], -1);
+ if (triple.isOSBinFormatELF()) {
+ // Check the ELF magic.
+ CHECK_EQ(result->object_file_data()[0], 0x7F);
+ CHECK_EQ(result->object_file_data()[1], 'E');
+ CHECK_EQ(result->object_file_data()[2], 'L');
+ CHECK_EQ(result->object_file_data()[3], 'F');
+ // Check the ELF class.
+ CHECK_EQ(result->object_file_data()[4], triple.isArch32Bit() ? 1 : 2);
+ // Check the ELF endianness: it should be little.
+ CHECK_EQ(result->object_file_data()[5], triple.isLittleEndian() ? 1 : 2);
+ // Check the ELF version: it should be 1.
+ CHECK_EQ(result->object_file_data()[6], 1);
+ }
+
+ const std::vector<char>& object_file_data = result->object_file_data();
+ std::cout.write(object_file_data.data(), object_file_data.size());
+
+ return 0;
+}
diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.cc b/tensorflow/compiler/xla/tests/local_client_test_base.cc
new file mode 100644
index 0000000000..5c32ed8895
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/local_client_test_base.cc
@@ -0,0 +1,220 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/tests/local_client_test_base.h"
+
+#include <vector>
+
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/map_util.h"
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace xla {
+
+/* static */ TestAllocator* LocalClientTestBase::allocator_;
+
+StatusOr<perftools::gputools::DeviceMemoryBase> TestAllocator::Allocate(
+ int device_ordinal, uint64 size, bool retry_on_failure) {
+ VLOG(2) << "Allocate(" << device_ordinal << ", " << size << ")";
+ {
+ tensorflow::mutex_lock lock(count_mutex_);
+ allocation_count_++;
+ device_allocation_count_[device_ordinal]++;
+ }
+ return StreamExecutorMemoryAllocator::Allocate(device_ordinal, size);
+}
+
+tensorflow::Status TestAllocator::Deallocate(
+ int device_ordinal, perftools::gputools::DeviceMemoryBase* mem) {
+ VLOG(2) << "Deallocate(" << device_ordinal << ")";
+ {
+ tensorflow::mutex_lock lock(count_mutex_);
+ deallocation_count_++;
+ device_deallocation_count_[device_ordinal]++;
+ }
+ return StreamExecutorMemoryAllocator::Deallocate(device_ordinal, mem);
+}
+
+int64 TestAllocator::allocation_count() const {
+ tensorflow::mutex_lock lock(count_mutex_);
+ return allocation_count_;
+}
+
+int64 TestAllocator::allocation_count(int device_ordinal) const {
+ tensorflow::mutex_lock lock(count_mutex_);
+ auto it = device_allocation_count_.find(device_ordinal);
+ if (it == device_allocation_count_.end()) {
+ return 0;
+ } else {
+ return it->second;
+ }
+}
+
+int64 TestAllocator::deallocation_count() const {
+ tensorflow::mutex_lock lock(count_mutex_);
+ return deallocation_count_;
+}
+
+int64 TestAllocator::deallocation_count(int device_ordinal) const {
+ tensorflow::mutex_lock lock(count_mutex_);
+ auto it = device_deallocation_count_.find(device_ordinal);
+ if (it == device_deallocation_count_.end()) {
+ return 0;
+ } else {
+ return it->second;
+ }
+}
+
+/* static */ TestAllocator* LocalClientTestBase::GetOrCreateAllocator(
+ perftools::gputools::Platform* platform) {
+ if (allocator_ == nullptr) {
+ allocator_ = new TestAllocator(
+ platform == nullptr ? PlatformUtil::GetDefaultPlatform().ValueOrDie()
+ : platform);
+ }
+ return allocator_;
+}
+
+LocalClientTestBase::LocalClientTestBase(
+ perftools::gputools::Platform* platform)
+ : local_client_(
+ ClientLibrary::GetOrCreateLocalClient(platform).ValueOrDie()) {
+ stream_executor_ = PlatformUtil::GetStreamExecutors(local_client_->platform())
+ .ValueOrDie()[local_client_->default_device_ordinal()];
+ transfer_manager_ =
+ TransferManager::GetForPlatform(local_client_->platform()).ValueOrDie();
+}
+
+std::unique_ptr<ScopedShapedBuffer>
+LocalClientTestBase::LiteralToScopedShapedBuffer(const Literal& literal) {
+ return LiteralToScopedShapedBuffer(literal,
+ local_client_->default_device_ordinal());
+}
+
+std::unique_ptr<ScopedShapedBuffer>
+LocalClientTestBase::LiteralToScopedShapedBuffer(const Literal& literal,
+ int device_ordinal) {
+ CHECK(!ShapeUtil::IsTuple(literal.shape()));
+ auto scoped_buffer =
+ ScopedShapedBuffer::MakeScopedShapedBuffer(
+ literal.shape(), GetOrCreateAllocator(local_client_->platform()),
+ device_ordinal)
+ .ConsumeValueOrDie();
+ // The creation of the scoped shaped buffer should allocate the buffer.
+ CHECK(!scoped_buffer->buffer(/*index=*/{}).is_null() ||
+ ShapeUtil::HasZeroElements(literal.shape()));
+ TF_CHECK_OK(transfer_manager_->TransferLiteralToDevice(
+ stream_executor_, literal, scoped_buffer->mutable_buffer(/*index=*/{})));
+ return scoped_buffer;
+}
+
+void LocalClientTestBase::CopyShapedBufferToLiteral(
+ const ShapedBuffer& shaped_buffer, ShapeIndex* index, Literal* literal) {
+ const Shape& shape = ShapeUtil::GetSubshape(shaped_buffer.shape(), *index);
+ if (ShapeUtil::IsTuple(shape)) {
+ *literal->mutable_shape() = shape;
+ for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
+ Literal* element_literal = literal->add_tuple_literals();
+ index->push_back(i);
+ CopyShapedBufferToLiteral(shaped_buffer, index, element_literal);
+ index->pop_back();
+ }
+ } else {
+ ASSERT_IS_OK(transfer_manager_->TransferLiteralFromDevice(
+ stream_executor_, shaped_buffer.buffer(*index), shape, shape, literal));
+ }
+}
+
+std::unique_ptr<Literal> LocalClientTestBase::ShapedBufferToLiteral(
+ const ShapedBuffer& shaped_buffer) {
+ auto literal = MakeUnique<Literal>();
+ ShapeIndex index;
+ CopyShapedBufferToLiteral(shaped_buffer, &index, literal.get());
+ return literal;
+}
+
+std::unique_ptr<ScopedShapedBuffer>
+LocalClientTestBase::ShapedBufferToScopedShapedBuffer(
+ std::unique_ptr<ShapedBuffer> shaped_buffer,
+ DeviceMemoryAllocator* allocator) {
+ std::unique_ptr<ScopedShapedBuffer> scoped_buffer =
+ ScopedShapedBuffer::MakeScopedShapedBuffer(
+ shaped_buffer->shape(), allocator, shaped_buffer->device_ordinal())
+ .ConsumeValueOrDie();
+ // Deallocate the existing DeviceMemoryBase values in the newly created scoped
+ // buffer and replace them with the values from the shaped buffer.
+ for (perftools::gputools::DeviceMemoryBase& memory_base :
+ *scoped_buffer->mutable_buffers()) {
+ TF_CHECK_OK(
+ allocator->Deallocate(shaped_buffer->device_ordinal(), &memory_base));
+ }
+ *scoped_buffer->mutable_buffers() = shaped_buffer->buffers();
+
+ TF_CHECK_OK(
+ scoped_buffer->mutable_shape_index_to_buffer_entry()
+ ->ForEachMutableElement(
+ [&shaped_buffer](const ShapeIndex& index, bool is_leaf,
+ size_t* buffer_entry) -> ::tensorflow::Status {
+ if (is_leaf) {
+ *buffer_entry =
+ shaped_buffer->shape_index_to_buffer_entry().element(
+ index);
+ }
+ return tensorflow::Status::OK();
+ }));
+ return scoped_buffer;
+}
+
+LocalExecuteOptions LocalClientTestBase::DefaultLocalExecuteOptions() const {
+ return LocalExecuteOptions().set_allocator(
+ GetOrCreateAllocator(local_client_->platform()));
+}
+
+std::unique_ptr<ScopedShapedBuffer> LocalClientTestBase::ExecuteLocally(
+ const Computation& computation,
+ tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments) {
+ return ExecuteLocally(computation, arguments, DefaultLocalExecuteOptions());
+}
+
+std::unique_ptr<ScopedShapedBuffer> LocalClientTestBase::ExecuteLocally(
+ const Computation& computation,
+ tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ const LocalExecuteOptions& options) {
+ return ShapedBufferToScopedShapedBuffer(
+ local_client_->ExecuteLocally(computation, arguments, options)
+ .ConsumeValueOrDie(),
+ options.allocator());
+}
+
+void LocalClientTestBase::ExecuteLocally(
+ const Computation& computation,
+ tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ ShapedBuffer* result) {
+ ExecuteLocally(computation, arguments, DefaultLocalExecuteOptions(), result);
+}
+
+void LocalClientTestBase::ExecuteLocally(
+ const Computation& computation,
+ tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ const LocalExecuteOptions& options, ShapedBuffer* result) {
+ ASSERT_IS_OK(
+ local_client_->ExecuteLocally(computation, arguments, options, result));
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.h b/tensorflow/compiler/xla/tests/local_client_test_base.h
new file mode 100644
index 0000000000..62916d50e3
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/local_client_test_base.h
@@ -0,0 +1,146 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_TESTS_LOCAL_CLIENT_TEST_BASE_H_
+#define TENSORFLOW_COMPILER_XLA_TESTS_LOCAL_CLIENT_TEST_BASE_H_
+
+#include <map>
+#include <memory>
+#include <vector>
+
+#include "tensorflow/compiler/xla/client/client_library.h"
+#include "tensorflow/compiler/xla/client/computation.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
+#include "tensorflow/compiler/xla/service/local_service.h"
+#include "tensorflow/compiler/xla/service/platform_util.h"
+#include "tensorflow/compiler/xla/service/shaped_buffer.h"
+#include "tensorflow/compiler/xla/service/transfer_manager.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+
+class TestAllocator : public StreamExecutorMemoryAllocator {
+ public:
+ explicit TestAllocator(perftools::gputools::Platform* platform)
+ : StreamExecutorMemoryAllocator(
+ platform, PlatformUtil::GetStreamExecutors(platform).ValueOrDie()) {
+ }
+
+ StatusOr<perftools::gputools::DeviceMemoryBase> Allocate(
+ int device_ordinal, uint64 size, bool retry_on_failure) override;
+ tensorflow::Status Deallocate(
+ int device_ordinal, perftools::gputools::DeviceMemoryBase* mem) override;
+
+ // Return the number of allocations that have been performed.
+ int64 allocation_count() const;
+ int64 allocation_count(int device_ordinal) const;
+
+ // Return the number of deallocations that have been performed.
+ int64 deallocation_count() const;
+ int64 deallocation_count(int device_ordinal) const;
+
+ private:
+ mutable tensorflow::mutex count_mutex_;
+
+ // Global counts of allocations and deallocations.
+ int64 allocation_count_ GUARDED_BY(count_mutex_) = 0;
+ int64 deallocation_count_ GUARDED_BY(count_mutex_) = 0;
+
+ // Per-device counts of allocations and deallocations.
+ std::map<int, int64> device_allocation_count_ GUARDED_BY(count_mutex_);
+ std::map<int, int64> device_deallocation_count_ GUARDED_BY(count_mutex_);
+};
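+
+// A usage sketch (hypothetical): tests deriving from LocalClientTestBase
+// (below) can check that every allocation made during execution was freed:
+//
+//   auto* allocator = GetOrCreateAllocator(local_client_->platform());
+//   // ... execute a computation ...
+//   EXPECT_EQ(allocator->allocation_count(), allocator->deallocation_count());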
+
+// A base class for tests which exercise the LocalClient interface.
+class LocalClientTestBase : public ::testing::Test {
+ protected:
+ explicit LocalClientTestBase(
+ perftools::gputools::Platform* platform = nullptr);
+
+ static TestAllocator* GetOrCreateAllocator(
+ perftools::gputools::Platform* platform);
+
+ // Copy the given literal onto the default device and return a
+ // ScopedShapedBuffer.
+ std::unique_ptr<ScopedShapedBuffer> LiteralToScopedShapedBuffer(
+ const Literal& literal);
+ // As above, but copy to a specific device.
+ std::unique_ptr<ScopedShapedBuffer> LiteralToScopedShapedBuffer(
+ const Literal& literal, int device_ordinal);
+
+ // Construct and return a literal containing the array represented by
+ // shaped_buffer.
+ std::unique_ptr<Literal> ShapedBufferToLiteral(
+ const ShapedBuffer& shaped_buffer);
+
+ // Helper for converting a ShapedBuffer into a literal.
+ void CopyShapedBufferToLiteral(const ShapedBuffer& shaped_buffer,
+ ShapeIndex* index, Literal* literal);
+
+  // Execute the given computation on the local client, with or without
+  // explicitly provided execute options.
+ std::unique_ptr<ScopedShapedBuffer> ExecuteLocally(
+ const Computation& computation,
+ tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments);
+ std::unique_ptr<ScopedShapedBuffer> ExecuteLocally(
+ const Computation& computation,
+ tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ const LocalExecuteOptions& options);
+
+ // Returns a default set of execute options, configured to use allocator_
+ // as the allocator.
+ LocalExecuteOptions DefaultLocalExecuteOptions() const;
+
+ // Overloads which write result into the given buffer.
+ void ExecuteLocally(
+ const Computation& computation,
+ tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ ShapedBuffer* result);
+ void ExecuteLocally(
+ const Computation& computation,
+ tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ const LocalExecuteOptions& options, ShapedBuffer* result);
+
+  // Convert a ShapedBuffer into a ScopedShapedBuffer so that all buffers are
+  // deallocated when the object is destroyed.
+ std::unique_ptr<ScopedShapedBuffer> ShapedBufferToScopedShapedBuffer(
+ std::unique_ptr<ShapedBuffer> shaped_buffer,
+ DeviceMemoryAllocator* allocator);
+
+ string TestName() const {
+ return ::testing::UnitTest::GetInstance()->current_test_info()->name();
+ }
+
+  // The allocator must live as long as the service, which lives until the end
+  // of the process, so the allocator is static.
+ static TestAllocator* allocator_;
+
+ perftools::gputools::StreamExecutor* stream_executor_;
+ TransferManager* transfer_manager_;
+
+ LocalClient* local_client_;
+};
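+
+// A minimal usage sketch (hypothetical; `computation` is assumed to be a
+// Computation built elsewhere in the test):
+//
+//   class MyLocalClientTest : public LocalClientTestBase {};
+//
+//   TEST_F(MyLocalClientTest, RunsAndReadsBack) {
+//     auto arg =
+//         LiteralToScopedShapedBuffer(*LiteralUtil::CreateR0<float>(1.0f));
+//     auto result = ExecuteLocally(computation, {arg.get()});
+//     std::unique_ptr<Literal> out = ShapedBufferToLiteral(*result);
+//   }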
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_TESTS_LOCAL_CLIENT_TEST_BASE_H_
diff --git a/tensorflow/compiler/xla/tests/log_test.cc b/tensorflow/compiler/xla/tests/log_test.cc
new file mode 100644
index 0000000000..b520d89de3
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/log_test.cc
@@ -0,0 +1,75 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <cmath>
+#include <vector>
+
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace {
+
+class LogTest : public ClientLibraryTestBase {};
+
+XLA_TEST_F(LogTest, LogZeroValues) {
+ ComputationBuilder builder(client_, TestName());
+ auto x = builder.ConstantR3FromArray3D<float>(Array3D<float>(3, 0, 0));
+ builder.Log(x);
+
+ ComputeAndCompareR3<float>(&builder, Array3D<float>(3, 0, 0), {},
+ ErrorSpec(0.0001));
+}
+
+TEST_F(LogTest, LogTenValues) {
+ std::vector<float> input = {-0.0, 1.0, 2.0, -3.0, -4.0,
+ 5.0, 6.0, -7.0, -8.0, 9.0};
+
+ ComputationBuilder builder(client_, TestName());
+ auto x = builder.ConstantR1<float>(input);
+ builder.Log(x);
+
+ std::vector<float> expected;
+ for (float f : input) {
+ expected.push_back(std::log(f));
+ }
+
+ ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/map_test.cc b/tensorflow/compiler/xla/tests/map_test.cc
new file mode 100644
index 0000000000..014417a205
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/map_test.cc
@@ -0,0 +1,589 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+
+#include "tensorflow/compiler/xla/array2d.h"
+#include "tensorflow/compiler/xla/client/computation.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/global_data.h"
+#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/tests/test_utils.h"
+#include "tensorflow/compiler/xla/xla.pb.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+class MapTest : public ClientLibraryTestBase {
+ public:
+ explicit MapTest(perftools::gputools::Platform* platform = nullptr)
+ : ClientLibraryTestBase(platform,
+ /*disabled_pass_names=*/{"algsimp", "inline"}) {}
+
+  // Creates a function that adds the constant 1.0 to its scalar argument.
+ //
+ // x {R0F32} ----> (add)
+ // /
+ // 1.0f ---------/
+ Computation CreateAdderToOne() {
+ ComputationBuilder mapped_builder(client_, TestName());
+ auto x = mapped_builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
+ auto one = mapped_builder.ConstantR0<float>(1.0);
+ auto adder_to_one = mapped_builder.Add(x, one);
+ auto computation_status = mapped_builder.Build();
+ TF_CHECK_OK(computation_status.status());
+ return computation_status.ConsumeValueOrDie();
+ }
+
+ Computation CreateMax() {
+ ComputationBuilder b(client_, TestName());
+ auto lhs = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
+ auto rhs = b.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y");
+ b.Max(lhs, rhs);
+ auto computation_status = b.Build();
+ TF_CHECK_OK(computation_status.status());
+ return computation_status.ConsumeValueOrDie();
+ }
+
+ // Creates a computation that accepts an F32 and returns T(1) (ignoring the
+ // argument).
+ template <class T>
+ Computation CreateScalarOne() {
+ ComputationBuilder mapped_builder(client_, "scalar_one");
+ (void)mapped_builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
+ mapped_builder.ConstantR0<T>(1);
+ auto computation_status = mapped_builder.Build();
+ TF_CHECK_OK(computation_status.status());
+ return computation_status.ConsumeValueOrDie();
+ }
+
+  // Creates a function that multiplies its scalar argument by the constant
+  // 2.0.
+ //
+ // x {R0F32} ----> (mul)
+ // /
+ // 2.0f ---------/
+ Computation CreateMulByTwo() {
+ ComputationBuilder mapped_builder(client_, TestName());
+ auto x = mapped_builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
+ auto two = mapped_builder.ConstantR0<float>(2.0);
+ auto mul_by_two = mapped_builder.Mul(x, two);
+ auto computation_status = mapped_builder.Build();
+ TF_CHECK_OK(computation_status.status());
+ return computation_status.ConsumeValueOrDie();
+ }
+
+  // Creates a function that adds the constant 1.0 to its scalar argument and
+  // then multiplies the sum by the original argument.
+ //
+ // /---------------\
+ // / \
+ // x {R0F32} ----> (add) ----> (mul)
+ // /
+ // 1.0f ---------/
+ Computation CreateAdderToOneTimesItself() {
+ ComputationBuilder mapped_builder(client_, TestName());
+ auto x = mapped_builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
+ auto one = mapped_builder.ConstantR0<float>(1.0);
+ auto adder_to_one = mapped_builder.Add(x, one);
+ auto result = mapped_builder.Mul(x, adder_to_one);
+ auto computation_status = mapped_builder.Build();
+ TF_CHECK_OK(computation_status.status());
+ return computation_status.ConsumeValueOrDie();
+ }
+
+ // Creates a function that takes a single parameter and calls map with
+ // "embedded_computation" on it, and then adds "n" to the result.
+ //
+ // x {R0F32} -----------> (map) ----> (add)
+ // / /
+ // embedded_computation --/ n --/
+ Computation CreateMapPlusN(const Computation& embedded_computation, float n) {
+ ComputationBuilder builder(client_, TestName());
+ auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
+ auto map = builder.Map({x}, embedded_computation);
+ auto constant_n = builder.ConstantR0<float>(n);
+ auto add = builder.Add(map, constant_n);
+ auto computation_status = builder.Build();
+ TF_CHECK_OK(computation_status.status());
+ return computation_status.ConsumeValueOrDie();
+ }
+
+ // Creates a binary function with signature (F32, F32) -> Pred
+ // defined by (x, y) -> x > y.
+ Computation CreateGt() {
+ ComputationBuilder b(client_, "Gt");
+ auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
+ auto y = b.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y");
+ auto gt = b.Gt(x, y);
+ auto computation_status = b.Build();
+ TF_CHECK_OK(computation_status.status());
+ return computation_status.ConsumeValueOrDie();
+ }
+
+  // Creates a function that adds three scalar arguments.
+ //
+ // x {R0F32} ----\
+ // \
+ // y {R0F32} ----> (add) ---> (add)
+ // /
+ // z {R0F32} ---------------/
+ Computation CreateTernaryAdder() {
+ ComputationBuilder mapped_builder(client_, "TernaryAdder");
+ auto x = mapped_builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
+ auto y = mapped_builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y");
+ auto z = mapped_builder.Parameter(2, ShapeUtil::MakeShape(F32, {}), "z");
+ auto xy = mapped_builder.Add(x, y);
+ auto xyz = mapped_builder.Add(xy, z);
+ auto computation_status = mapped_builder.Build();
+ TF_CHECK_OK(computation_status.status());
+ return computation_status.ConsumeValueOrDie();
+ }
+};
+
+TEST_F(MapTest, MapEachElemPlusOneR0) {
+  // Applies (lambda (x) (+ x 1)) to an input scalar.
+ ComputationBuilder builder(client_, TestName());
+ std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR0<float>(42.0);
+ std::unique_ptr<GlobalData> param0_data =
+ client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+
+ auto param = builder.Parameter(0, param0_literal->shape(), "param0");
+ auto map = builder.Map({param}, CreateAdderToOne());
+
+ ComputeAndCompareR0<float>(&builder, 43.0, {param0_data.get()},
+ ErrorSpec(0.01f));
+}
+
+XLA_TEST_F(MapTest, MapEachElemPlusOneR1S0) {
+ // Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 0.
+ ComputationBuilder builder(client_, TestName());
+ std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR1<float>({});
+ std::unique_ptr<GlobalData> param0_data =
+ client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+
+ auto param = builder.Parameter(0, param0_literal->shape(), "param0");
+ auto map = builder.Map({param}, CreateAdderToOne());
+
+ ComputeAndCompareR1<float>(&builder, {}, {param0_data.get()},
+ ErrorSpec(0.01f));
+}
+
+TEST_F(MapTest, MapEachElemPlusOneR1S4) {
+ // Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 4.
+ ComputationBuilder builder(client_, TestName());
+ std::unique_ptr<Literal> param0_literal =
+ LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
+ std::unique_ptr<GlobalData> param0_data =
+ client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+
+ auto param = builder.Parameter(0, param0_literal->shape(), "param0");
+ auto map = builder.Map({param}, CreateAdderToOne());
+
+ ComputeAndCompareR1<float>(&builder, {3.2f, 4.3f, 5.4f, 6.5f},
+ {param0_data.get()}, ErrorSpec(0.01f));
+}
+
+TEST_F(MapTest, MapEachF32ElementToS32Constant) {
+ ComputationBuilder builder(client_, TestName());
+ std::unique_ptr<Literal> param0_literal =
+ LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
+ std::unique_ptr<GlobalData> param0_data =
+ client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+
+ auto param = builder.Parameter(0, param0_literal->shape(), "param0");
+ auto map = builder.Map({param}, CreateScalarOne<int32>());
+
+ ComputeAndCompareR1<int32>(&builder, {1, 1, 1, 1}, {param0_data.get()});
+}
+
+TEST_F(MapTest, MapEachF32ElementToU32Constant) {
+ ComputationBuilder builder(client_, TestName());
+ std::unique_ptr<Literal> param0_literal =
+ LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
+ std::unique_ptr<GlobalData> param0_data =
+ client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+
+ auto param = builder.Parameter(0, param0_literal->shape(), "param0");
+ auto map = builder.Map({param}, CreateScalarOne<uint32>());
+
+ ComputeAndCompareR1<uint32>(&builder, {1, 1, 1, 1}, {param0_data.get()});
+}
+
+TEST_F(MapTest, MapEachElemLongerChainR1) {
+ // Maps (lambda (x) (* (+ x 1) x)) onto an input R1F32 vector.
+ ComputationBuilder builder(client_, TestName());
+ std::unique_ptr<Literal> param0_literal =
+ LiteralUtil::CreateR1<float>({2.6f, -5.1f, 0.1f, 0.2f, 999.0f, 255.5f});
+ std::unique_ptr<GlobalData> param0_data =
+ client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+
+ auto param = builder.Parameter(0, param0_literal->shape(), "param0");
+ auto map = builder.Map({param}, CreateAdderToOneTimesItself());
+
+ ComputeAndCompareR1<float>(
+ &builder, {9.36f, 20.91f, 0.11f, 0.24f, 999000.0f, 65535.75f},
+ {param0_data.get()}, ErrorSpec(0.01f));
+}
+
+XLA_TEST_F(MapTest, MapMultipleMapsR1S0) {
+ // Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 0, and then
+ // maps (lambda (x) (* x 2)) on the result.
+ ComputationBuilder builder(client_, TestName());
+ std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR1<float>({});
+ std::unique_ptr<GlobalData> param0_data =
+ client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+
+ auto param = builder.Parameter(0, param0_literal->shape(), "param0");
+ auto map1 = builder.Map({param}, CreateAdderToOne());
+ auto map2 = builder.Map({map1}, CreateMulByTwo());
+
+ ComputeAndCompareR1<float>(&builder, {}, {param0_data.get()},
+ ErrorSpec(0.01f));
+}
+
+TEST_F(MapTest, MapMultipleMapsR1S4) {
+ // Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 4, and then
+ // maps (lambda (x) (* x 2)) on the result.
+ ComputationBuilder builder(client_, TestName());
+ std::unique_ptr<Literal> param0_literal =
+ LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
+ std::unique_ptr<GlobalData> param0_data =
+ client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+
+ auto param = builder.Parameter(0, param0_literal->shape(), "param0");
+ auto map1 = builder.Map({param}, CreateAdderToOne());
+ auto map2 = builder.Map({map1}, CreateMulByTwo());
+
+ ComputeAndCompareR1<float>(&builder, {6.4f, 8.6f, 10.8f, 13.0f},
+ {param0_data.get()}, ErrorSpec(0.01f));
+}
+
+TEST_F(MapTest, MapEachElemPlusOneR2) {
+ // Maps (lambda (x) (+ x 1)) onto an input R2F32 vector.
+ ComputationBuilder builder(client_, TestName());
+ std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR2<float>(
+ {{13.25f, 14.0f}, {-7.1f, -7.2f}, {-8.8f, 8.8f}});
+ std::unique_ptr<GlobalData> param0_data =
+ client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+
+ auto param = builder.Parameter(0, param0_literal->shape(), "param0");
+ auto map = builder.Map({param}, CreateAdderToOne());
+
+ Array2D<float> expected_array(
+ {{14.25f, 15.0f}, {-6.1f, -6.2f}, {-7.8f, 9.8f}});
+ ComputeAndCompareR2<float>(&builder, expected_array, {param0_data.get()},
+ ErrorSpec(0.01f));
+}
+
+XLA_TEST_F(MapTest, ComplexNestedMaps) {
+ // Constructs a complex graph of embedded computations to test the computation
+ // lowering order. Python equivalent:
+ //
+ // embed1 = lambda x: x + 1 # x + 1
+ // embed2 = lambda x: embed1(x) + 2 # x + 3
+ // embed3 = lambda x: embed1(x) + 4 # x + 5
+ // embed4 = lambda x: embed2(x) + embed3(x) # 2x + 8
+ // embed5 = lambda x: embed2(x) + 6 # x + 9
+ // result = embed5(42) + embed4(7) # (42 + 9) + (2 * 7 + 8) = 73
+
+ Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
+
+ auto embed1 = CreateAdderToOne();
+ auto embed2 = CreateMapPlusN(embed1, 2.0);
+ auto embed3 = CreateMapPlusN(embed1, 4.0);
+
+ ComputationBuilder embed4_builder(client_, "embed4");
+ auto embed4_param = embed4_builder.Parameter(0, scalar_shape, "x");
+ auto embed4_map_lhs = embed4_builder.Map({embed4_param}, embed2);
+ auto embed4_map_rhs = embed4_builder.Map({embed4_param}, embed3);
+ auto embed4_add = embed4_builder.Add(embed4_map_lhs, embed4_map_rhs);
+ auto embed4_status = embed4_builder.Build();
+ ASSERT_IS_OK(embed4_status.status());
+ auto embed4 = embed4_status.ConsumeValueOrDie();
+
+ auto embed5 = CreateMapPlusN(embed2, 6.0);
+
+ ComputationBuilder builder(client_, TestName());
+ auto constant_42 = builder.ConstantR0<float>(42.0);
+ auto constant_7 = builder.ConstantR0<float>(7.0);
+ auto map_42 = builder.Map({constant_42}, embed5);
+ auto map_7 = builder.Map({constant_7}, embed4);
+ builder.Add(map_42, map_7);
+
+ ComputeAndCompareR0<float>(&builder, 73.0, {}, ErrorSpec(0.01f));
+}
+
+TEST_F(MapTest, VersionedEmbeddedComputation) {
+ // Build a computation X, use it in a map, then add an additional operation to
+ // computation X and use it again in a different map. Verify that the proper
+ // versions of computation X are used in each of the maps.
+
+  // Create an (embedded) computation which adds one to its parameter argument.
+ ComputationBuilder embedded_builder(client_, "EmbeddedComputation");
+ auto param_0 =
+ embedded_builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param0");
+ auto constant_one = embedded_builder.ConstantR0<float>(1.0);
+ auto adder_to_one = embedded_builder.Add(param_0, constant_one);
+ auto computation_status = embedded_builder.Build();
+ ASSERT_IS_OK(computation_status.status());
+ auto embedded_computation = computation_status.ConsumeValueOrDie();
+
+ ComputationBuilder builder(client_, TestName());
+ auto constant_vector = builder.ConstantR1<float>({1.0, 2.0, 3.0, 4.0});
+ auto map_plus_1 = builder.Map({constant_vector}, embedded_computation);
+
+  // Add another Add(1) operation to the existing embedded computation. This
+  // requires using the stub interface because the ComputationBuilder does not
+  // allow modification of Computation objects after they have been built.
+ BinaryOpRequest request;
+ request.set_binop(BINOP_ADD);
+ *request.mutable_lhs() = adder_to_one;
+ *request.mutable_rhs() = constant_one;
+ OpRequest op_request;
+ *op_request.mutable_computation() = embedded_computation.handle();
+ *op_request.mutable_binary_op_request() = request;
+ OpResponse response;
+ tensorflow::Status s = client_->stub()->Op(&op_request, &response);
+ ASSERT_TRUE(s.ok());
+
+ auto map_plus_2 = builder.Map({map_plus_1}, embedded_computation);
+
+  // The first map applies Add(1) to the original vector; the second map uses
+  // the updated computation, which now adds 2, for a net Add(3).
+ ComputeAndCompareR1<float>(&builder, {4.0, 5.0, 6.0, 7.0}, {},
+ ErrorSpec(0.01f));
+}
+
+TEST_F(MapTest, MapBinaryAdder) {
+ // Maps (lambda (x y) (+ x y)) onto two R1F32 vectors.
+ ComputationBuilder builder(client_, TestName());
+ std::unique_ptr<Literal> param0_literal =
+ LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
+ std::unique_ptr<GlobalData> param0_data =
+ client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ std::unique_ptr<Literal> param1_literal =
+ LiteralUtil::CreateR1<float>({5.1f, 4.4f, -0.1f, -5.5f});
+ std::unique_ptr<GlobalData> param1_data =
+ client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+
+ auto param0 = builder.Parameter(0, param0_literal->shape(), "param0");
+ auto param1 = builder.Parameter(1, param1_literal->shape(), "param1");
+ auto map =
+ builder.Map({param0, param1}, CreateScalarAddComputation(F32, &builder));
+
+  ComputeAndCompareR1<float>(&builder, {7.3f, 7.7f, 4.3f, 0.0f},
+ {param0_data.get(), param1_data.get()},
+ ErrorSpec(0.01f));
+}
+
+// Adds two rank-2 arrays with different layouts. This test exercises a path
+// for Map that used to fail in shape inference (b/28989438).
+XLA_TEST_F(MapTest, AddWithMixedLayouts) {
+ ComputationBuilder builder(client_, TestName());
+ std::unique_ptr<Literal> param0_literal =
+ test_utils::CreateR2LiteralWithLayout({{1, 2}, {3, 4}}, {1, 0});
+ std::unique_ptr<GlobalData> param0_data =
+ client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+
+ std::unique_ptr<Literal> param1_literal =
+ test_utils::CreateR2LiteralWithLayout({{10, 20}, {30, 40}}, {0, 1});
+ std::unique_ptr<GlobalData> param1_data =
+ client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+
+ auto param0 = builder.Parameter(0, param0_literal->shape(), "param0");
+ auto param1 = builder.Parameter(1, param1_literal->shape(), "param1");
+ auto map =
+ builder.Map({param0, param1}, CreateScalarAddComputation(S32, &builder));
+
+ Array2D<int32> expected(2, 2);
+ expected(0, 0) = 11;
+ expected(0, 1) = 22;
+ expected(1, 0) = 33;
+ expected(1, 1) = 44;
+ ComputeAndCompareR2<int32>(&builder, expected,
+ {param0_data.get(), param1_data.get()});
+}
+
+XLA_TEST_F(MapTest, AddR3_3x0x2) {
+ ComputationBuilder builder(client_, TestName());
+ std::unique_ptr<Literal> param0_literal =
+ LiteralUtil::CreateR3FromArray3D<int32>(Array3D<int32>(3, 0, 2));
+ std::unique_ptr<GlobalData> param0_data =
+ client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+
+ std::unique_ptr<Literal> param1_literal =
+ LiteralUtil::CreateR3FromArray3D<int32>(Array3D<int32>(3, 0, 2));
+ std::unique_ptr<GlobalData> param1_data =
+ client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+
+ auto param0 = builder.Parameter(0, param0_literal->shape(), "param0");
+ auto param1 = builder.Parameter(1, param1_literal->shape(), "param1");
+ auto map =
+ builder.Map({param0, param1}, CreateScalarAddComputation(S32, &builder));
+
+ ComputeAndCompareR3<int32>(&builder, Array3D<int32>(3, 0, 2),
+ {param0_data.get(), param1_data.get()});
+}
+
+TEST_F(MapTest, MapTernaryAdder) {
+ // Maps (lambda (x y z) (+ x y z)) onto three R1F32 vectors.
+ ComputationBuilder builder(client_, TestName());
+ std::unique_ptr<Literal> param0_literal =
+ LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
+ std::unique_ptr<GlobalData> param0_data =
+ client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ std::unique_ptr<Literal> param1_literal =
+ LiteralUtil::CreateR1<float>({5.1f, 4.4f, -0.1f, -5.5f});
+ std::unique_ptr<GlobalData> param1_data =
+ client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+ std::unique_ptr<Literal> param2_literal =
+ LiteralUtil::CreateR1<float>({-10.0f, -100.0f, -900.0f, -400.0f});
+ std::unique_ptr<GlobalData> param2_data =
+ client_->TransferToServer(*param2_literal).ConsumeValueOrDie();
+
+ auto param0 = builder.Parameter(0, param0_literal->shape(), "param0");
+ auto param1 = builder.Parameter(1, param1_literal->shape(), "param1");
+ auto param2 = builder.Parameter(2, param2_literal->shape(), "param2");
+ auto map = builder.Map({param0, param1, param2}, CreateTernaryAdder());
+
+ ComputeAndCompareR1<float>(
+ &builder, {-2.7f, -92.3f, -895.7f, -400.0f},
+ {param0_data.get(), param1_data.get(), param2_data.get()},
+ ErrorSpec(0.01f));
+}
+
+TEST_F(MapTest, MapGt) {
+ // Maps (x,y) -> x > y onto two R1F32 vectors.
+ ComputationBuilder b(client_, TestName());
+ auto gt = CreateGt();
+ b.Map({b.ConstantR1<float>({1, 20}), b.ConstantR1<float>({10, 2})}, gt);
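+  // Element-wise: 1 > 10 is false, 20 > 2 is true.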
+ ComputeAndCompareR1<bool>(&b, {false, true}, {});
+}
+
+TEST_F(MapTest, NestedBinaryMap) {
+ Computation max_with_square;
+ {
+    // max_with_square(x) computes max(x, x^2) via a map.
+ ComputationBuilder b(client_, "max_with_square");
+ auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
+ b.Map({x, b.Mul(x, x)}, CreateMax());
+ auto computation_status = b.Build();
+ ASSERT_IS_OK(computation_status.status());
+ max_with_square = computation_status.ConsumeValueOrDie();
+ }
+ ComputationBuilder b(client_, TestName());
+ auto input = b.ConstantR1<float>({0.1f, 0.5f, -0.5f, 1.0f, 2.0f});
+ b.Map({input}, max_with_square);
+ ComputeAndCompareR1<float>(&b, {0.1f, 0.5f, 0.25f, 1.0f, 4.0f}, {});
+}
+
+TEST_F(MapTest, MapOperationWithBuildError) {
+ // Maps (lambda (x y) (+ x y)) onto two R1F32 vectors but uses an unsupported
+ // type combination (F32 + U16) to test that the error is reported to the
+ // outermost ComputationBuilder.
+ ComputationBuilder builder(client_, TestName());
+
+ auto sub_builder = builder.CreateSubBuilder("ErrorAdd");
+ auto x = sub_builder->Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
+ auto y = sub_builder->Parameter(1, ShapeUtil::MakeShape(U16, {}), "y");
+ auto adder = sub_builder->Add(x, y);
+ auto error_add = sub_builder->BuildAndNoteError();
+
+ std::unique_ptr<Literal> param0_literal =
+ LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
+ std::unique_ptr<GlobalData> param0_data =
+ client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ std::unique_ptr<Literal> param1_literal =
+ LiteralUtil::CreateR1<float>({5.1f, 4.4f, -0.1f, -5.5f});
+ std::unique_ptr<GlobalData> param1_data =
+ client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+
+ auto param0 = builder.Parameter(0, param0_literal->shape(), "param0");
+ auto param1 = builder.Parameter(1, param1_literal->shape(), "param1");
+ auto map = builder.Map({param0, param1}, error_add);
+
+ StatusOr<Computation> computation_status = builder.Build();
+ ASSERT_TRUE(!computation_status.ok());
+ EXPECT_MATCH(computation_status.status().ToString(),
+ testing::HasSubstr("error from: ErrorAdd: binary op with "
+ "different element types: f32[] and u16[]"));
+}
+
+// MapTest disables the inliner and algebraic simplifier passes;
+// MapTestWithFullOpt runs all optimizations.
+using MapTestWithFullOpt = ClientLibraryTestBase;
+
+// Regression test for b/31466798. The inliner simplifies map(param0, param1,
+// power) to power(param0, param1) without deleting the old subcomputation,
+// which is identical to the new entry computation.
+// HloSubcomputationUnification used to mishandle such patterns and could
+// invalidate the pointer to the entry computation.
+TEST_F(MapTestWithFullOpt, MapScalarPower) {
+ ComputationBuilder builder(client_, TestName());
+
+ auto sub_builder = builder.CreateSubBuilder("power");
+ auto x = sub_builder->Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
+ auto y = sub_builder->Parameter(1, ShapeUtil::MakeShape(F32, {}), "y");
+ sub_builder->Pow(x, y);
+ auto power = sub_builder->BuildAndNoteError();
+
+ std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR0<float>(2.0f);
+ std::unique_ptr<Literal> param1_literal = LiteralUtil::CreateR0<float>(5.0f);
+ std::unique_ptr<GlobalData> param0_data =
+ client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ std::unique_ptr<GlobalData> param1_data =
+ client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+
+ auto param0 = builder.Parameter(0, param0_literal->shape(), "param0");
+ auto param1 = builder.Parameter(1, param1_literal->shape(), "param1");
+ builder.Map({param0, param1}, power);
+
+ ComputeAndCompareR0<float>(&builder, 32.0f,
+ {param0_data.get(), param1_data.get()},
+ ErrorSpec(0.01f));
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc
new file mode 100644
index 0000000000..8aa4029440
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc
@@ -0,0 +1,179 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <algorithm>
+#include <memory>
+#include <string>
+
+#include "tensorflow/compiler/xla/array2d.h"
+#include "tensorflow/compiler/xla/client/computation.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/reference_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+class MatOpsSimpleTest : public ClientLibraryTestBase {
+ protected:
+ Computation BuildSum() {
+ // sum(x, y) = x + y
+ ComputationBuilder builder(client_, "sum");
+ auto x_value =
+ builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x_value");
+ auto y_value =
+ builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y_value");
+ builder.Add(x_value, y_value);
+ auto computation_status = builder.Build();
+ TF_CHECK_OK(computation_status.status());
+ return computation_status.ConsumeValueOrDie();
+ }
+
+ void TestLinspaceMax(int64 rows, int64 cols) {
+ float from = -128.0, to = 256.0;
+ std::unique_ptr<Array2D<float>> alhs =
+ MakeLinspaceArray2D(from, to, rows, cols);
+ auto arhs = MakeUnique<Array2D<float>>(rows, cols, 1.0);
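+    // The lhs holds rows * cols values evenly spaced from -128 to 256, while
+    // the rhs is a constant 1.0, so the expected max is computed element-wise
+    // below.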
+
+ ComputationBuilder builder(
+ client_,
+ tensorflow::strings::Printf("max_%lldx%lld_linspace", rows, cols));
+ auto lhs = builder.ConstantR2FromArray2D<float>(*alhs);
+ auto rhs = builder.ConstantR2FromArray2D<float>(*arhs);
+ auto max = builder.Max(lhs, rhs);
+
+ Array2D<float> aexpected(rows, cols);
+ for (int row = 0; row < rows; ++row) {
+ for (int col = 0; col < cols; ++col) {
+ aexpected(row, col) = std::max((*alhs)(row, col), (*arhs)(row, col));
+ }
+ }
+
+ ComputeAndCompareR2<float>(&builder, aexpected, {}, ErrorSpec(1e-6));
+ }
+};
+
+TEST_F(MatOpsSimpleTest, ExpTwoByTwoValues) {
+ ComputationBuilder builder(client_, "exp_2x2");
+ auto data = builder.ConstantR2<float>({
+ {1.0, 0.0}, // row 0
+ {-1.0, 0.5}, // row 1
+ });
+ builder.Exp(data);
+
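+  // Expected values: e^1 ~= 2.71828, e^0 = 1, e^-1 ~= 0.36788,
+  // e^0.5 ~= 1.64872.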
+ std::unique_ptr<Literal> expected =
+ LiteralUtil::CreateR2<float>({{2.71828, 1.00000}, // row 0
+ {0.36788, 1.64872}}); // row 1
+
+ ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-5));
+}
+
+TEST_F(MatOpsSimpleTest, MapTwoByTwo) {
+ Computation add_half;
+ {
+ // add_half(x) = x + 0.5
+ ComputationBuilder builder(client_, "add_half");
+ auto x_value =
+ builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x_value");
+ auto half = builder.ConstantR0<float>(0.5);
+ builder.Add(x_value, half);
+ auto computation_status = builder.Build();
+ ASSERT_IS_OK(computation_status.status());
+ add_half = computation_status.ConsumeValueOrDie();
+ }
+
+ ComputationBuilder builder(client_, "map_2x2");
+ auto data = builder.ConstantR2<float>({
+ {1.0, 0.0}, // row 0
+ {-1.0, 0.5}, // row 1
+ });
+ auto map = builder.Map({data}, add_half);
+
+ std::unique_ptr<Literal> expected =
+ LiteralUtil::CreateR2<float>({{1.5, 0.5}, // row 0
+ {-0.5, 1.0}}); // row 1
+ ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-5));
+}
+
+TEST_F(MatOpsSimpleTest, MaxTwoByTwoValues) {
+ ComputationBuilder builder(client_, "max_2x2");
+ auto lhs = builder.ConstantR2<float>({
+ {7.0, 2.0}, // row 0
+ {3.0, -4.0}, // row 1
+ });
+ auto rhs = builder.ConstantR2<float>({
+ {5.0, 6.0}, // row 0
+ {1.0, -8.0}, // row 1
+ });
+ auto max = builder.Max(lhs, rhs);
+
+ std::unique_ptr<Literal> expected =
+ LiteralUtil::CreateR2<float>({{7.0, 6.0}, // row 0
+ {3.0, -4.0}}); // row 1
+ ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6));
+}
+
+TEST_F(MatOpsSimpleTest, Max1x1Linspace) { TestLinspaceMax(1, 1); }
+
+TEST_F(MatOpsSimpleTest, Max2x2Linspace) { TestLinspaceMax(2, 2); }
+
+TEST_F(MatOpsSimpleTest, Max3x3Linspace) { TestLinspaceMax(3, 3); }
+
+TEST_F(MatOpsSimpleTest, Max4x4Linspace) { TestLinspaceMax(4, 4); }
+
+TEST_F(MatOpsSimpleTest, Max6x6Linspace) { TestLinspaceMax(6, 6); }
+
+TEST_F(MatOpsSimpleTest, Max8x8Linspace) { TestLinspaceMax(8, 8); }
+
+TEST_F(MatOpsSimpleTest, Max12x12Linspace) { TestLinspaceMax(12, 12); }
+
+TEST_F(MatOpsSimpleTest, Max16x16Linspace) { TestLinspaceMax(16, 16); }
+
+TEST_F(MatOpsSimpleTest, Max32x8Linspace) { TestLinspaceMax(32, 8); }
+
+TEST_F(MatOpsSimpleTest, Max64x8Linspace) { TestLinspaceMax(64, 8); }
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/multidimensional_slice_test.cc b/tensorflow/compiler/xla/tests/multidimensional_slice_test.cc
new file mode 100644
index 0000000000..2cd680399b
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/multidimensional_slice_test.cc
@@ -0,0 +1,74 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Tests that slice operations can be performed.
+
+#include <memory>
+
+#include "tensorflow/compiler/xla/array2d.h"
+#include "tensorflow/compiler/xla/array3d.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace {
+
+class SliceTest : public ClientLibraryTestBase {};
+
+XLA_TEST_F(SliceTest, Slice2D) {
+ ComputationBuilder builder(client_, "slice_2d");
+ auto original = builder.ConstantR2<float>(
+ {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}, {7.0, 8.0, 9.0}, {10.0, 11.0, 12.0}});
+ builder.Slice(original, {2, 1}, {4, 3});
+
+ Array2D<float> expected({{8.0f, 9.0f}, {11.0f, 12.0f}});
+ ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.000001));
+}
+
+XLA_TEST_F(SliceTest, Slice3D) {
+ ComputationBuilder builder(client_, "slice_3d");
+ Array3D<float> array_3d(
+ {{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}});
+ auto original = builder.ConstantR3FromArray3D<float>(array_3d);
+ builder.Slice(original, {0, 0, 1}, {2, 1, 2});
+
+ Array3D<float> expected_3d({{{2.0f}}, {{6.0f}}});
+ ComputeAndCompareR3<float>(&builder, expected_3d, {}, ErrorSpec(0.000001));
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/pad_test.cc b/tensorflow/compiler/xla/tests/pad_test.cc
new file mode 100644
index 0000000000..d3400b432f
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/pad_test.cc
@@ -0,0 +1,420 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+#include <vector>
+
+#include "tensorflow/compiler/xla/array2d.h"
+#include "tensorflow/compiler/xla/array4d.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/reference_util.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+class PadTest : public ClientLibraryTestBase {
+ protected:
+ PadTest() {
+    // Initializes the padding configuration used for R4 tests.
+    // Pads only dimension 0 with {low: 1, high: 0, interior: 2} and
+    // dimension 1 with {low: 0, high: 2, interior: 1}.
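+    // For each dimension, the padded size is
+    //   low + size + high + max(size - 1, 0) * interior,
+    // since interior padding is only inserted between existing elements.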
+ auto dimension0 = r4_padding_on_dim0_dim1_.add_dimensions();
+ dimension0->set_edge_padding_low(1);
+ dimension0->set_edge_padding_high(0);
+ dimension0->set_interior_padding(2);
+ auto dimension1 = r4_padding_on_dim0_dim1_.add_dimensions();
+ dimension1->set_edge_padding_low(0);
+ dimension1->set_edge_padding_high(2);
+ dimension1->set_interior_padding(1);
+ auto dimension2 = r4_padding_on_dim0_dim1_.add_dimensions();
+ dimension2->set_edge_padding_low(0);
+ dimension2->set_edge_padding_high(0);
+ dimension2->set_interior_padding(0);
+ auto dimension3 = r4_padding_on_dim0_dim1_.add_dimensions();
+ dimension3->set_edge_padding_low(0);
+ dimension3->set_edge_padding_high(0);
+ dimension3->set_interior_padding(0);
+ }
+
+ // Padding configuration for R4 that only pads dimension 0 and 1.
+ PaddingConfig r4_padding_on_dim0_dim1_;
+};
+
+// Tests a Pad() with a zero-element input and output.
+XLA_TEST_F(PadTest, Pad1DS0ToS0Array) {
+ ComputationBuilder b(client_, TestName());
+ // Set up the padding configuration {low: 0, high: 0, interior: 0}.
+ PaddingConfig padding_config;
+ auto dimension = padding_config.add_dimensions();
+ dimension->set_edge_padding_low(0);
+ dimension->set_edge_padding_high(0);
+ dimension->set_interior_padding(0);
+
+ b.Pad(b.ConstantR1<float>({}), b.ConstantR0<float>(0.1), padding_config);
+ ComputeAndCompareR1<float>(&b, {}, {}, ErrorSpec(0.0001));
+}
+
+// Tests a Pad() with a zero-element input but a non-zero-element output.
+XLA_TEST_F(PadTest, Pad1DS0ToS5Array) {
+ ComputationBuilder b(client_, TestName());
+  // Set up the padding configuration {low: 1, high: 4, interior: 7}.
+ PaddingConfig padding_config;
+ auto dimension = padding_config.add_dimensions();
+ dimension->set_edge_padding_low(1);
+ dimension->set_edge_padding_high(4);
+ dimension->set_interior_padding(7);
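+  // With a zero-element input there is nothing to pad between, so the
+  // interior padding contributes nothing and the output has 1 + 4 = 5
+  // elements.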
+
+ b.Pad(b.ConstantR1<float>({}), b.ConstantR0<float>(0.1), padding_config);
+ ComputeAndCompareR1<float>(&b, std::vector<float>(5, 0.1), {},
+ ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(PadTest, Pad1DS3Array) {
+ ComputationBuilder b(client_, TestName());
+ // Set up the padding configuration {low: 3, high: 0, interior: 1}.
+ PaddingConfig padding_config;
+ auto dimension = padding_config.add_dimensions();
+ dimension->set_edge_padding_low(3);
+ dimension->set_edge_padding_high(0);
+ dimension->set_interior_padding(1);
+
+ b.Pad(b.ConstantR1<float>({1, 2, 3}), b.ConstantR0<float>(0.1),
+ padding_config);
+  std::vector<float> expected(
+      {0.1f, 0.1f, 0.1f, 1.0f, 0.1f, 2.0f, 0.1f, 3.0f});
+ ComputeAndCompareR1<float>(&b, expected, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(PadTest, Pad4D_2x0x3x2_FloatArray) {
+ ComputationBuilder b(client_, TestName());
+ b.Pad(b.ConstantR4FromArray4D<float>(Array4D<float>(2, 0, 3, 2)),
+ b.ConstantR0<float>(1.5), r4_padding_on_dim0_dim1_);
+ ComputeAndCompareR4<float>(&b, Array4D<float>(5, 2, 3, 2, 1.5f), {},
+ ErrorSpec(0.0001));
+}
+
+TEST_F(PadTest, Pad4DFloat_1x1x3x2_Array) {
+ ComputationBuilder b(client_, TestName());
+ auto input = MakeUnique<Array4D<float>>(1, 1, 3, 2);
+ Array2D<float> input_xy({
+ {1.0f, 2.0f}, // row 0
+ {3.0f, 4.0f}, // row 1
+ {5.0f, 6.0f}, // row 2
+ });
+ input->FillWithYX(input_xy);
+
+ b.Pad(b.ConstantR4FromArray4D<float>(*input), b.ConstantR0<float>(1.5),
+ r4_padding_on_dim0_dim1_);
+
+ auto expected = MakeUnique<Array4D<float>>(2, 3, 3, 2);
+ expected->Fill(1.5);
+ (*expected)(1, 0, 0, 0) = 1.0f;
+ (*expected)(1, 0, 0, 1) = 2.0f;
+ (*expected)(1, 0, 1, 0) = 3.0f;
+ (*expected)(1, 0, 1, 1) = 4.0f;
+ (*expected)(1, 0, 2, 0) = 5.0f;
+ (*expected)(1, 0, 2, 1) = 6.0f;
+ ComputeAndCompareR4<float>(&b, *expected, {}, ErrorSpec(0.0001));
+}
+
+TEST_F(PadTest, Pad4DFloatArrayWithInteriorPadding) {
+ ComputationBuilder b(client_, TestName());
+
+ const float pad_value = 1.5f;
+ Array4D<float> input(3, 2, 1, 1, {1, 2, 3, 4, 5, 6});
+ b.Pad(b.ConstantR4FromArray4D<float>(input), b.ConstantR0<float>(pad_value),
+ r4_padding_on_dim0_dim1_);
+
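+  // Input element (i, j, 0, 0) lands at (1 + 3 * i, 2 * j, 0, 0): dimension 0
+  // has low padding 1 and interior padding 2, dimension 1 has interior
+  // padding 1.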
+ auto expected = MakeUnique<Array4D<float>>(8, 5, 1, 1);
+ expected->Fill(pad_value);
+ (*expected)(1, 0, 0, 0) = 1.0f;
+ (*expected)(1, 2, 0, 0) = 2.0f;
+ (*expected)(4, 0, 0, 0) = 3.0f;
+ (*expected)(4, 2, 0, 0) = 4.0f;
+ (*expected)(7, 0, 0, 0) = 5.0f;
+ (*expected)(7, 2, 0, 0) = 6.0f;
+ ComputeAndCompareR4<float>(&b, *expected, {}, ErrorSpec(0.0001));
+}
+
+TEST_F(PadTest, Pad4DFloatArrayMinorFirstSmall) {
+ ComputationBuilder b(client_, TestName());
+
+ PaddingConfig padding_config;
+ auto dimension0 = padding_config.add_dimensions();
+ dimension0->set_edge_padding_low(0);
+ dimension0->set_edge_padding_high(0);
+ dimension0->set_interior_padding(0);
+ auto dimension1 = padding_config.add_dimensions();
+ dimension1->set_edge_padding_low(0);
+ dimension1->set_edge_padding_high(0);
+ dimension1->set_interior_padding(0);
+ auto dimension2 = padding_config.add_dimensions();
+ dimension2->set_edge_padding_low(2);
+ dimension2->set_edge_padding_high(1);
+ dimension2->set_interior_padding(0);
+ auto dimension3 = padding_config.add_dimensions();
+ dimension3->set_edge_padding_low(2);
+ dimension3->set_edge_padding_high(3);
+ dimension3->set_interior_padding(0);
+
+ const Layout layout = LayoutUtil::MakeLayout({0, 1, 2, 3});
+
+ const float pad_value = -5.123f;
+ Array4D<float> input_array(1, 1, 2, 3, {1, 2, 3, 4, 5, 6});
+ auto input = LiteralUtil::CreateR4FromArray4D<float>(input_array);
+ input = LiteralUtil::Relayout(*input, layout);
+
+ b.Pad(b.ConstantLiteral(*input), b.ConstantR0(pad_value), padding_config);
+
+ Array4D<float> expected_array(1, 1, 5, 8);
+ expected_array.Fill(pad_value);
+ expected_array(0, 0, 2, 2) = 1.0f;
+ expected_array(0, 0, 2, 3) = 2.0f;
+ expected_array(0, 0, 2, 4) = 3.0f;
+ expected_array(0, 0, 3, 2) = 4.0f;
+ expected_array(0, 0, 3, 3) = 5.0f;
+ expected_array(0, 0, 3, 4) = 6.0f;
+ ComputeAndCompareR4<float>(&b, expected_array, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(PadTest, Pad4DFloatArrayMinorFirstNonTrivialMinorDimensions) {
+ ComputationBuilder b(client_, TestName());
+
+ PaddingConfig padding_config;
+ auto dimension0 = padding_config.add_dimensions();
+ dimension0->set_edge_padding_low(0);
+ dimension0->set_edge_padding_high(0);
+ dimension0->set_interior_padding(0);
+ auto dimension1 = padding_config.add_dimensions();
+ dimension1->set_edge_padding_low(0);
+ dimension1->set_edge_padding_high(0);
+ dimension1->set_interior_padding(0);
+ auto dimension2 = padding_config.add_dimensions();
+ dimension2->set_edge_padding_low(2);
+ dimension2->set_edge_padding_high(2);
+ dimension2->set_interior_padding(1);
+ auto dimension3 = padding_config.add_dimensions();
+ dimension3->set_edge_padding_low(2);
+ dimension3->set_edge_padding_high(2);
+ dimension3->set_interior_padding(0);
+
+ const Layout layout = LayoutUtil::MakeLayout({0, 1, 2, 3});
+
+ const float pad_value = -5.123f;
+ Array4D<float> input_array(1, 25, 7, 7);
+ input_array.Fill(pad_value);
+ input_array(0, 0, 0, 0) = 1.0f;
+ input_array(0, 24, 6, 6) = 2.0f;
+ input_array(0, 17, 2, 5) = 3.0f;
+ auto input = LiteralUtil::CreateR4FromArray4D<float>(input_array);
+ input = LiteralUtil::Relayout(*input, layout);
+
+ b.Pad(b.ConstantLiteral(*input), b.ConstantR0(pad_value), padding_config);
+
+ Array4D<float> expected_array(1, 25, 17, 11);
+ expected_array.Fill(pad_value);
+ expected_array(0, 0, 2, 2) = 1.0f;
+ expected_array(0, 24, 14, 8) = 2.0f;
+ expected_array(0, 17, 6, 7) = 3.0f;
+ ComputeAndCompareR4<float>(&b, expected_array, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(PadTest, Pad4DU8Array) {
+ ComputationBuilder b(client_, TestName());
+ auto input = MakeUnique<Array4D<uint8>>(1, 1, 3, 2);
+ Array2D<uint8> input_xy({
+ {1, 2}, // row 0
+ {3, 4}, // row 1
+ {5, 6}, // row 2
+ });
+ input->FillWithYX(input_xy);
+
+ b.Pad(b.ConstantR4FromArray4D<uint8>(*input), b.ConstantR0<uint8>(35),
+ r4_padding_on_dim0_dim1_);
+
+ auto expected = MakeUnique<Array4D<uint8>>(2, 3, 3, 2);
+ expected->Fill(35);
+ (*expected)(1, 0, 0, 0) = 1;
+ (*expected)(1, 0, 0, 1) = 2;
+ (*expected)(1, 0, 1, 0) = 3;
+ (*expected)(1, 0, 1, 1) = 4;
+ (*expected)(1, 0, 2, 0) = 5;
+ (*expected)(1, 0, 2, 1) = 6;
+ ComputeAndCompareR4<uint8>(&b, *expected, {});
+}
+
+XLA_TEST_F(PadTest, Pad4DPredArray) {
+ ComputationBuilder b(client_, TestName());
+
+  // Since bool is currently not well supported, use a Broadcast operation to
+  // create the operand for Pad.
+ auto input = b.Broadcast(b.ConstantR0<bool>(true), {1, 1, 3, 2});
+ auto padded =
+ b.Pad(input, b.ConstantR0<bool>(false), r4_padding_on_dim0_dim1_);
+
+ // For the same reason, use Select to convert boolean values to int32.
+ auto zeros = MakeUnique<Array4D<int32>>(2, 3, 3, 2);
+ auto ones = MakeUnique<Array4D<int32>>(2, 3, 3, 2);
+ zeros->Fill(0);
+ ones->Fill(1);
+ b.Select(padded, b.ConstantR4FromArray4D<int32>(*ones),
+ b.ConstantR4FromArray4D<int32>(*zeros));
+
+ auto expected = MakeUnique<Array4D<int32>>(2, 3, 3, 2);
+ expected->Fill(0);
+ (*expected)(1, 0, 0, 0) = 1;
+ (*expected)(1, 0, 0, 1) = 1;
+ (*expected)(1, 0, 1, 0) = 1;
+ (*expected)(1, 0, 1, 1) = 1;
+ (*expected)(1, 0, 2, 0) = 1;
+ (*expected)(1, 0, 2, 1) = 1;
+ ComputeAndCompareR4<int32>(&b, *expected, {});
+}
+
+XLA_TEST_F(PadTest, Large2DPad) {
+ ComputationBuilder b(client_, TestName());
+
+ auto input = b.Parameter(0, ShapeUtil::MakeShape(F32, {4, 4}), "input");
+ PaddingConfig padding_config = MakeNoPaddingConfig(2);
+ for (int dim : {0, 1}) {
+ padding_config.mutable_dimensions(dim)->set_edge_padding_low(
+ 98 + 100 * (1 - dim));
+ padding_config.mutable_dimensions(dim)->set_edge_padding_high(58 +
+ 100 * dim);
+ }
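+  // This yields {low: 198, high: 58} on dimension 0 and {low: 98, high: 158}
+  // on dimension 1.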
+ auto padded = b.Pad(input, b.ConstantR0<float>(0.0f), padding_config);
+
+ auto ones = MakeUnique<Array2D<float>>(4, 4);
+ ones->Fill(1.0f);
+ auto input_literal = LiteralUtil::CreateR2FromArray2D<float>(*ones);
+ std::unique_ptr<GlobalData> input_data =
+ client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+
+ auto expected = ReferenceUtil::PadArray2D(*ones, padding_config, 0.0f);
+ ComputeAndCompareR2<float>(&b, *expected, {input_data.get()});
+}
+
+XLA_TEST_F(PadTest, AllTypes2DPad) {
+ ComputationBuilder b(client_, TestName());
+
+ constexpr int64 in_rows = 35;
+ constexpr int64 in_cols = 35;
+ auto input =
+ b.Parameter(0, ShapeUtil::MakeShape(F32, {in_rows, in_cols}), "input");
+ PaddingConfig padding_config = MakeNoPaddingConfig(2);
+ padding_config.mutable_dimensions(0)->set_edge_padding_low(7);
+ padding_config.mutable_dimensions(0)->set_edge_padding_high(5);
+ padding_config.mutable_dimensions(0)->set_interior_padding(3);
+ padding_config.mutable_dimensions(1)->set_edge_padding_low(6);
+ padding_config.mutable_dimensions(1)->set_edge_padding_high(4);
+ padding_config.mutable_dimensions(1)->set_interior_padding(2);
+ auto padded = b.Pad(input, b.ConstantR0<float>(3.14f), padding_config);
+
+ auto operand = MakeUnique<Array2D<float>>(in_rows, in_cols);
+ operand->FillUnique(0.0f);
+ auto input_literal = LiteralUtil::CreateR2FromArray2D<float>(*operand);
+ std::unique_ptr<GlobalData> input_data =
+ client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+
+ auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 3.14f);
+ ComputeAndCompareR2<float>(&b, *expected, {input_data.get()},
+ ErrorSpec{0.0001});
+}
+
+XLA_TEST_F(PadTest, High2DPad) {
+ ComputationBuilder b(client_, TestName());
+
+ constexpr int64 in_rows = 129;
+ constexpr int64 in_cols = 129;
+ constexpr int64 low_padding = 0;
+ int64 high_padding[2] = {5, 7};
+ constexpr int64 interior_padding = 0;
+ auto input =
+ b.Parameter(0, ShapeUtil::MakeShape(F32, {in_rows, in_cols}), "input");
+ PaddingConfig padding_config = MakeNoPaddingConfig(2);
+ for (int dim : {0, 1}) {
+ padding_config.mutable_dimensions(dim)->set_edge_padding_low(low_padding);
+ padding_config.mutable_dimensions(dim)->set_edge_padding_high(
+ high_padding[dim]);
+ padding_config.mutable_dimensions(dim)->set_interior_padding(
+ interior_padding);
+ }
+ auto padded = b.Pad(input, b.ConstantR0<float>(2.718f), padding_config);
+
+ auto operand = MakeUnique<Array2D<float>>(in_rows, in_cols);
+ operand->FillUnique(1.0f);
+ auto input_literal = LiteralUtil::CreateR2FromArray2D<float>(*operand);
+ auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 2.718f);
+ std::unique_ptr<GlobalData> input_data =
+ client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+
+ ComputeAndCompareR2<float>(&b, *expected, {input_data.get()},
+ ErrorSpec(0.0001));
+}
+
+// Regression test for b/31827337.
+XLA_TEST_F(PadTest, ReducePad) {
+ ComputationBuilder b(client_, TestName());
+ auto input = b.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2, 2, 2}), "input");
+
+ Computation add_f32 = CreateScalarAddComputation(F32, &b);
+ auto reduce = b.Reduce(input, b.ConstantR0<float>(0.0), add_f32, {0});
+
+ PaddingConfig padding_config = MakeNoPaddingConfig(3);
+ padding_config.mutable_dimensions(0)->set_edge_padding_low(1);
+ padding_config.mutable_dimensions(0)->set_edge_padding_high(1);
+ auto pad = b.Pad(reduce, b.ConstantR0<float>(0.0), padding_config);
+
+ auto ones = MakeUnique<Array4D<float>>(2, 2, 2, 2);
+ ones->Fill(1.0);
+ auto input_literal = LiteralUtil::CreateR4FromArray4D<float>(*ones);
+ std::unique_ptr<GlobalData> input_data =
+ client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+
+ Array3D<float> expected({{{0.0, 0.0}, {0.0, 0.0}},
+ {{2.0, 2.0}, {2.0, 2.0}},
+ {{2.0, 2.0}, {2.0, 2.0}},
+ {{0.0, 0.0}, {0.0, 0.0}}});
+ ComputeAndCompareR3<float>(&b, expected, {input_data.get()});
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/params_test.cc b/tensorflow/compiler/xla/tests/params_test.cc
new file mode 100644
index 0000000000..2f05576cee
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/params_test.cc
@@ -0,0 +1,357 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <algorithm>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "tensorflow/compiler/xla/array2d.h"
+#include "tensorflow/compiler/xla/client/computation.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/global_data.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+class ParamsTest : public ClientLibraryTestBase {};
+
+XLA_TEST_F(ParamsTest, ConstantR0F32Param) {
+ ComputationBuilder builder(client_, TestName());
+ std::unique_ptr<Literal> param0_literal =
+ LiteralUtil::CreateR0<float>(3.14159f);
+ std::unique_ptr<GlobalData> param0_data =
+ client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+
+ auto p = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param0");
+
+ ComputeAndCompareR0<float>(&builder, 3.14159f, {param0_data.get()},
+ ErrorSpec(0.0001f));
+}
+
+XLA_TEST_F(ParamsTest, ConstantR1S0F32Param) {
+ ComputationBuilder builder(client_, TestName());
+ std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR1<float>({});
+ std::unique_ptr<GlobalData> param0_data =
+ client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+
+ auto p = builder.Parameter(0, ShapeUtil::MakeShape(F32, {0}), "param0");
+
+ ComputeAndCompareR1<float>(&builder, {}, {param0_data.get()},
+ ErrorSpec(0.01f));
+}
+
+XLA_TEST_F(ParamsTest, ConstantR1S2F32Param) {
+ ComputationBuilder builder(client_, TestName());
+ std::unique_ptr<Literal> param0_literal =
+ LiteralUtil::CreateR1<float>({3.14f, -100.25f});
+ std::unique_ptr<GlobalData> param0_data =
+ client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+
+ auto p = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2}), "param0");
+
+ ComputeAndCompareR1<float>(&builder, {3.14f, -100.25f}, {param0_data.get()},
+ ErrorSpec(0.01f));
+}
+
+XLA_TEST_F(ParamsTest, ConstantR1U8Param) {
+ ComputationBuilder builder(client_, TestName());
+ string str("hello world");
+ std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR1U8(str);
+ std::unique_ptr<GlobalData> param0_data =
+ client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+
+ auto p = builder.Parameter(
+ 0, ShapeUtil::MakeShape(U8, {static_cast<int64>(str.size())}), "param0");
+
+ ComputeAndCompareR1U8(&builder, str, {param0_data.get()});
+}
+
+XLA_TEST_F(ParamsTest, ConstantR2_3x0_F32Param) {
+ ComputationBuilder builder(client_, TestName());
+ std::unique_ptr<Literal> param0_literal =
+ LiteralUtil::CreateR2FromArray2D<float>(Array2D<float>(3, 0));
+ std::unique_ptr<GlobalData> param0_data =
+ client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+
+ auto p = builder.Parameter(0, ShapeUtil::MakeShape(F32, {3, 0}), "param0");
+
+ ComputeAndCompareR2<float>(&builder, Array2D<float>(3, 0),
+ {param0_data.get()}, ErrorSpec(0.01f));
+}
+
+XLA_TEST_F(ParamsTest, ConstantR2F32Param) {
+ ComputationBuilder builder(client_, TestName());
+ std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR2<float>(
+ {{3.14f, -100.25f}, {7e8f, 7e-9f}, {30.3f, -100.0f}});
+ std::unique_ptr<GlobalData> param0_data =
+ client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+
+ auto p = builder.Parameter(0, ShapeUtil::MakeShape(F32, {3, 2}), "param0");
+
+ Array2D<float> expected_array(
+ {{3.14f, -100.25f}, {7e8f, 7e-9f}, {30.3f, -100.0f}});
+ ComputeAndCompareR2<float>(&builder, expected_array, {param0_data.get()},
+ ErrorSpec(0.01f));
+}
+
+XLA_TEST_F(ParamsTest, TwoParameters) {
+ ComputationBuilder builder(client_, TestName());
+
+ std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>({1, 2});
+ std::unique_ptr<GlobalData> param0_data =
+ client_->TransferToServer(*literal0).ConsumeValueOrDie();
+ auto param0 = builder.Parameter(0, literal0->shape(), "param0");
+
+ std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>({10, 20});
+ std::unique_ptr<GlobalData> param1_data =
+ client_->TransferToServer(*literal1).ConsumeValueOrDie();
+ auto param1 = builder.Parameter(1, literal1->shape(), "param1");
+
+ // Use both parameters
+ //
+ // {1, 2} + {10, 20} = {11, 22}
+  auto sum = builder.Add(param0, param1);
+
+  // Use only the second parameter again, to show that it can be used twice.
+  // This also makes the computation asymmetric in its two parameters, which
+  // tests that they are not swapped.
+ //
+ // {11, 22} * {10, 20} = {110, 440}
+ auto prod = builder.Mul(sum, param1);
+
+ ComputeAndCompareR1<float>(&builder, {110, 440},
+ {param0_data.get(), param1_data.get()},
+ ErrorSpec(0.0001f));
+}
+
+XLA_TEST_F(ParamsTest, MissingParameter) {
+ // Test that an error is returned when a computation with an incomplete set of
+ // parameters (parameter numbers not contiguous from 0) is executed.
+ std::unique_ptr<Literal> literal = LiteralUtil::CreateR0<float>(3.14159f);
+ std::unique_ptr<GlobalData> data =
+ client_->TransferToServer(*literal).ConsumeValueOrDie();
+
+ ComputationBuilder builder(client_, TestName());
+ auto p = builder.Parameter(2, ShapeUtil::MakeShape(F32, {}), "param2");
+ auto computation = builder.Build().ConsumeValueOrDie();
+
+ auto execute_status = client_->Execute(computation, {data.get(), data.get()},
+ /*output_layout=*/nullptr,
+ /*execution_profile=*/nullptr);
+ ASSERT_EQ(execute_status.status().code(),
+ tensorflow::error::FAILED_PRECONDITION);
+}
+
+XLA_TEST_F(ParamsTest, UnusedParameter) {
+ ComputationBuilder builder(client_, TestName());
+
+ std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>({1, 2});
+ std::unique_ptr<GlobalData> param0_data =
+ client_->TransferToServer(*literal0).ConsumeValueOrDie();
+ auto param0 = builder.Parameter(0, literal0->shape(), "param0");
+
+ std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>({10, 20});
+ std::unique_ptr<GlobalData> param1_data =
+ client_->TransferToServer(*literal1).ConsumeValueOrDie();
+ auto param1 = builder.Parameter(1, literal1->shape(), "param1");
+
+ ComputeAndCompareR1<float>(&builder, {10, 20},
+ {param0_data.get(), param1_data.get()},
+ ErrorSpec(0.0001f));
+}
+
+XLA_TEST_F(ParamsTest, UnusedParametersInUnusedExpression) {
+  // Build a computation with a couple of parameters that are used only within
+  // an unused expression.
+ ComputationBuilder builder(client_, TestName());
+
+ std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>({1, 2});
+ std::unique_ptr<GlobalData> param0_data =
+ client_->TransferToServer(*literal0).ConsumeValueOrDie();
+
+ std::unique_ptr<Literal> literal1 =
+ LiteralUtil::CreateR1<float>({10, 20, 30});
+ std::unique_ptr<GlobalData> param1_data =
+ client_->TransferToServer(*literal1).ConsumeValueOrDie();
+
+ auto param0 = builder.Parameter(0, literal0->shape(), "param0");
+ auto param1 = builder.Parameter(1, literal1->shape(), "param1");
+ auto param2 = builder.Parameter(2, literal1->shape(), "param2");
+
+ // This add is unused.
+ builder.Add(param1, param2);
+
+ builder.Neg(param0);
+
+ ComputeAndCompareR1<float>(
+ &builder, {-1, -2},
+ {param0_data.get(), param1_data.get(), param1_data.get()},
+ ErrorSpec(0.0001f));
+}
+
+XLA_TEST_F(ParamsTest, HundredLargeR1Parameters) {
+ ComputationBuilder builder(client_, TestName());
+ constexpr int size = 8 * 128 * 2;
+
+  std::vector<float> init_value = {0, 1};
+  init_value.resize(size);
+  ComputationDataHandle sum_handle = builder.ConstantR1<float>(init_value);
+  std::vector<float> sum = {0, 1};
+  sum.resize(size);
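+  // Only the first two entries of each parameter are nonzero, so only the
+  // first two entries of the running sum change.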
+
+ std::vector<std::unique_ptr<GlobalData>> param_data_owner;
+
+ constexpr int parameter_count = 100;
+ for (int i = 0; i < parameter_count; ++i) {
+ const float entry0 = i;
+ const float entry1 = 2 * i;
+ sum[0] += entry0;
+ sum[1] += entry1;
+
+    std::vector<float> sum_value = {entry0, entry1};
+ sum_value.resize(size);
+ std::unique_ptr<Literal> literal = LiteralUtil::CreateR1<float>(sum_value);
+ param_data_owner.push_back(
+ client_->TransferToServer(*literal).ConsumeValueOrDie());
+ ComputationDataHandle param =
+ builder.Parameter(i, literal->shape(), "param");
+ sum_handle = builder.Add(sum_handle, param);
+ }
+
+ std::vector<GlobalData*> param_data;
+ for (const std::unique_ptr<GlobalData>& data : param_data_owner) {
+ param_data.push_back(data.get());
+ }
+
+ ComputeAndCompareR1<float>(&builder, sum, param_data, ErrorSpec(0.0001f));
+}
+
+XLA_TEST_F(ParamsTest,
+ DISABLED_ON_CPU_PARALLEL(TupleOfR1ParametersAddedTogether)) {
+ ComputationBuilder builder(client_, TestName());
+
+ Shape r1f32_3 = ShapeUtil::MakeShape(F32, {3});
+ Shape tuple_shape = ShapeUtil::MakeTupleShape({r1f32_3, r1f32_3});
+ auto input = builder.Parameter(0, tuple_shape, "input");
+ auto lhs = builder.GetTupleElement(input, 0);
+ auto rhs = builder.GetTupleElement(input, 1);
+ builder.Add(lhs, rhs);
+
+ std::unique_ptr<GlobalData> data =
+ client_
+ ->TransferToServer(*LiteralUtil::MakeTuple({
+ LiteralUtil::CreateR1<float>({1, 2, 3}).get(),
+ LiteralUtil::CreateR1<float>({4, 5, 6}).get(),
+ }))
+ .ConsumeValueOrDie();
+
+ std::vector<GlobalData*> arguments = {data.get()};
+ const std::vector<float> expected = {1 + 4, 2 + 5, 3 + 6};
+ ComputeAndCompareR1<float>(&builder, expected, arguments, ErrorSpec(1e-5));
+}
+
+// Verifies that passing a 2x2 with {0, 1} layout returns the same value back
+// when (transferred to the server and) passed through a parameter.
+XLA_TEST_F(ParamsTest, R2_2x2_Layout_01) {
+ std::unique_ptr<Literal> literal = LiteralUtil::CreateR2<float>({
+ {1, 2}, {3, 4},
+ });
+ *literal->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1});
+ ComputationBuilder builder(client_, TestName());
+ builder.Parameter(0, literal->shape(), "input");
+
+ std::unique_ptr<GlobalData> data =
+ client_->TransferToServer(*literal).ConsumeValueOrDie();
+ ComputeAndCompareLiteral(&builder, *literal, {data.get()}, ErrorSpec(1e-3));
+}
+
+// As above, but for {1, 0} layout.
+XLA_TEST_F(ParamsTest, R2_2x2_Layout_10) {
+ std::unique_ptr<Literal> literal = LiteralUtil::CreateR2<float>({
+ {1, 3}, {2, 4},
+ });
+ *literal->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({1, 0});
+ ComputationBuilder builder(client_, TestName());
+ builder.Parameter(0, literal->shape(), "input");
+
+ std::unique_ptr<GlobalData> data =
+ client_->TransferToServer(*literal).ConsumeValueOrDie();
+ ComputeAndCompareLiteral(&builder, *literal, {data.get()}, ErrorSpec(1e-3));
+}
+
+XLA_TEST_F(ParamsTest, R2_2x2_TryToPassReverseLayoutToParameter) {
+ std::unique_ptr<Literal> literal = LiteralUtil::CreateR2<float>({
+ {1, 3}, {2, 4},
+ });
+ const Shape original = literal->shape();
+ {
+ // Reverse the layout present in original, and make that the layout of the
+ // literal.
+ std::vector<int64> original_layout(
+ original.layout().minor_to_major().begin(),
+ original.layout().minor_to_major().end());
+ std::reverse(original_layout.begin(), original_layout.end());
+ *literal->mutable_shape()->mutable_layout() =
+ LayoutUtil::MakeLayout(original_layout);
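+    // The underlying data is unchanged, so with the reversed layout the
+    // logical element (0, 1) now reads 2.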
+ ASSERT_EQ(2, LiteralUtil::Get<float>(*literal, {0, 1}));
+ }
+ // Use the original shape in building the computation.
+ ComputationBuilder builder(client_, TestName());
+ auto input = builder.Parameter(0, original, "input");
+ // Use the slice operator to get an off-diagonal element.
+ builder.Slice(input, {0, 1}, {1, 2});
+
+ std::unique_ptr<GlobalData> data =
+ client_->TransferToServer(*literal).ConsumeValueOrDie();
+ // Check that we got the off-diagonal value that we expected.
+ Array2D<float> expected(1, 1);
+ expected(0, 0) = 2;
+ ComputeAndCompareR2(&builder, expected, {data.get()}, ErrorSpec(1e-3));
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/pred_test.cc b/tensorflow/compiler/xla/tests/pred_test.cc
new file mode 100644
index 0000000000..96393c41e8
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/pred_test.cc
@@ -0,0 +1,115 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Miscellaneous tests with the PRED type that don't fit anywhere else.
+#include <memory>
+
+#include "tensorflow/compiler/xla/array2d.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace {
+
+class PredTest : public ClientLibraryTestBase {
+ protected:
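+  // Applies the given scalar comparison op to two PRED constants and checks
+  // that the result matches `expected`.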
+ void TestCompare(bool lhs, bool rhs, bool expected,
+ ComputationDataHandle (ComputationBuilder::*op)(
+ const ComputationDataHandle&,
+ const ComputationDataHandle&,
+ tensorflow::gtl::ArraySlice<int64>)) {
+ ComputationBuilder builder(client_, TestName());
+ ComputationDataHandle lhs_op = builder.ConstantR0<bool>(lhs);
+ ComputationDataHandle rhs_op = builder.ConstantR0<bool>(rhs);
+ ComputationDataHandle result = (builder.*op)(lhs_op, rhs_op, {});
+ ComputeAndCompareR0<bool>(&builder, expected, {});
+ }
+};
+
+TEST_F(PredTest, ConstantR0PredTrue) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR0<bool>(true);
+ ComputeAndCompareR0<bool>(&builder, true, {});
+}
+
+TEST_F(PredTest, ConstantR0PredFalse) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR0<bool>(false);
+ ComputeAndCompareR0<bool>(&builder, false, {});
+}
+
+TEST_F(PredTest, ConstantR0PredCompareEq) {
+ TestCompare(true, false, false, &ComputationBuilder::Eq);
+}
+
+TEST_F(PredTest, ConstantR0PredCompareNe) {
+ TestCompare(true, false, true, &ComputationBuilder::Ne);
+}
+
+TEST_F(PredTest, ConstantR0PredCompareLe) {
+ TestCompare(true, false, false, &ComputationBuilder::Le);
+}
+
+TEST_F(PredTest, ConstantR0PredCompareLt) {
+ TestCompare(true, false, false, &ComputationBuilder::Lt);
+}
+
+TEST_F(PredTest, ConstantR0PredCompareGe) {
+ TestCompare(true, false, true, &ComputationBuilder::Ge);
+}
+
+TEST_F(PredTest, ConstantR0PredCompareGt) {
+ TestCompare(true, false, true, &ComputationBuilder::Gt);
+}
+
+TEST_F(PredTest, ConstantR1Pred) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<bool>({true, false, false, true});
+ ComputeAndCompareR1<bool>(&builder, {true, false, false, true}, {});
+}
+
+TEST_F(PredTest, ConstantR2Pred) {
+ ComputationBuilder builder(client_, TestName());
+ auto a =
+ builder.ConstantR2<bool>({{false, true, true}, {true, false, false}});
+ const string expected = R"(pred[2,3] {
+ { 011 },
+ { 100 },
+})";
+ EXPECT_EQ(expected, ExecuteToString(&builder, {}));
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/prng_test.cc b/tensorflow/compiler/xla/tests/prng_test.cc
new file mode 100644
index 0000000000..8d77b3dd61
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/prng_test.cc
@@ -0,0 +1,238 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <cmath>
+#include <memory>
+
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/primitive_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+class PrngTest : public ClientLibraryTestBase {
+ protected:
+ template <typename T>
+ void UniformTest(T a, T b, tensorflow::gtl::ArraySlice<int64> dims);
+ void BernoulliTest(float p, tensorflow::gtl::ArraySlice<int64> dims);
+};
+
+template <typename T>
+void PrngTest::UniformTest(T a, T b, tensorflow::gtl::ArraySlice<int64> dims) {
+ ComputationBuilder builder(client_, TestName());
+ builder.RngUniform(
+ builder.ConstantR0<T>(a), builder.ConstantR0<T>(b),
+ ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType<T>(), dims));
+
+ auto actual = ExecuteAndTransferOrDie(&builder, /*arguments=*/{});
+ EXPECT_TRUE(ContainersEqual(dims, actual->shape().dimensions()));
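+  // RngUniform draws from the half-open interval [a, b); the upper-bound
+  // check below uses EXPECT_GE, which also accepts the endpoint b itself.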
+ LiteralUtil::EachCell<T>(*actual,
+ [=](tensorflow::gtl::ArraySlice<int64>, T value) {
+ EXPECT_LE(a, value);
+ EXPECT_GE(b, value);
+ });
+}
+
+void PrngTest::BernoulliTest(float p, tensorflow::gtl::ArraySlice<int64> dims) {
+ ComputationBuilder builder(client_, TestName());
+ auto shape = ShapeUtil::MakeShape(U32, dims);
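+  // RngBernoulli yields 0/1 samples with P(1) = p; a U32 result keeps the
+  // accumulated sum exact.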
+ builder.RngBernoulli(builder.ConstantR0<float>(p), shape);
+
+ TF_ASSIGN_OR_ASSERT_OK(auto computation, builder.Build());
+ constexpr uint64 kTestSeed = 42;
+ TF_ASSIGN_OR_ASSERT_OK(
+ auto actual,
+ client_->ExecuteAndTransfer(computation, /*arguments=*/{},
+ /*shape_with_output_layout=*/nullptr,
+ /*execution_profile=*/nullptr,
+ /*seed=*/kTestSeed));
+ EXPECT_TRUE(ContainersEqual(dims, actual->shape().dimensions()));
+ int32 sum = 0;
+ LiteralUtil::EachCell<uint32>(
+ *actual, [&sum](tensorflow::gtl::ArraySlice<int64>, uint32 value) {
+ EXPECT_TRUE(value == 0 || value == 1);
+ sum += value;
+ });
+ int32 total = ShapeUtil::ElementsIn(shape);
+ float p_tilde = sum / static_cast<float>(total);
+
+  // Check that p_tilde lies within the expected range using a normal
+  // approximation. With a fixed seed the output is deterministic for a given
+  // p and backend, but the test is invoked for different values of p and the
+  // backends may use different random number generators, so an exact check
+  // is not possible. Choose a 95% confidence level, so that
+  // z_{1-\alpha/2} = 1.96.
+ float normal_approximation_term = 1.96 * sqrt(p * (1 - p) / total);
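+  // For example, with p=0.5 and total=100 the term is
+  // 1.96 * sqrt(0.25 / 100) ~= 0.098, so p_tilde must lie in [0.402, 0.598].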
+ EXPECT_GE(p_tilde, p - normal_approximation_term);
+ EXPECT_LE(p_tilde, p + normal_approximation_term);
+}
+
+// Uniform random number generation tests
+XLA_TEST_F(PrngTest, ScalarU01) { UniformTest<float>(0, 1, {}); }
+XLA_TEST_F(PrngTest, ZeroValuesU01) { UniformTest<float>(0, 1, {0}); }
+XLA_TEST_F(PrngTest, TenValuesU01) { UniformTest<float>(0, 1, {10}); }
+XLA_TEST_F(PrngTest, TenValuesU37) { UniformTest<float>(3, 7, {10}); }
+XLA_TEST_F(PrngTest, ZeroValuesR2) { UniformTest<float>(0, 1, {0, 20}); }
+XLA_TEST_F(PrngTest, LargeU01) { UniformTest<float>(0, 1, {0x100, 0x100}); }
+XLA_TEST_F(PrngTest, TwelveValuesU524) { UniformTest<int32>(5, 24, {12}); }
+
+XLA_TEST_F(PrngTest, MapUsingRng) {
+  // Build an x -> (x + U[0,1)) computation.
+ auto build_sum_rng = [this](ComputationBuilder& builder) {
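+    // CreateSubBuilder nests a computation under this builder;
+    // BuildAndNoteError records any build failure on the parent builder
+    // instead of returning a Status.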
+ auto b = builder.CreateSubBuilder("sum_with_rng");
+ auto x = b->Parameter(0, ShapeUtil::MakeShape(F32, {}), "input");
+ b->Add(x,
+ b->RngUniform(b->ConstantR0<float>(0), b->ConstantR0<float>(1),
+ ShapeUtil::MakeShape(F32, {})));
+ return b->BuildAndNoteError();
+ };
+
+ ComputationBuilder builder(client_, TestName());
+ std::unique_ptr<Literal> param0_literal =
+ LiteralUtil::CreateR1<float>({2.2f, 5.3f, 4.4f, 5.5f});
+ TF_ASSIGN_OR_ASSERT_OK(std::unique_ptr<GlobalData> param0_data,
+ client_->TransferToServer(*param0_literal));
+
+ auto param0 = builder.Parameter(0, param0_literal->shape(), "param0");
+ auto fn = build_sum_rng(builder);
+ builder.Map({param0}, fn);
+
+ TF_ASSIGN_OR_ASSERT_OK(auto computation, builder.Build());
+ TF_ASSIGN_OR_ASSERT_OK(
+ auto actual,
+ client_->ExecuteAndTransfer(computation,
+ /*arguments=*/{param0_data.get()}, nullptr,
+ nullptr, /*seed=*/125));
+ EXPECT_EQ(actual->f32s_size(), param0_literal->f32s_size());
+ for (int i = 0; i < param0_literal->f32s_size(); ++i) {
+ EXPECT_GE(actual->f32s(i), param0_literal->f32s(i));
+ EXPECT_LT(actual->f32s(i), param0_literal->f32s(i) + 1.0f);
+ }
+}
+
+// This test demonstrates the global seeding behaviour.
+// * If a seed is passed in via Execute (ExecuteAndTransfer) then the output
+//   is fixed (i.e., there is a single output for a given seed);
+// * If no seed is passed in then the output of every call can be different.
+XLA_TEST_F(PrngTest, PassInGlobalRngSeed) {
+ // Build a U[0,1) computation.
+ auto build_computation = [this]() {
+ ComputationBuilder builder(client_, TestName());
+ builder.RngUniform(builder.ConstantR0<float>(0),
+ builder.ConstantR0<float>(1),
+ ShapeUtil::MakeShape(F32, {10}));
+ return builder.Build();
+ };
+
+ std::unique_ptr<Literal> result1;
+ {
+ TF_ASSIGN_OR_ASSERT_OK(auto computation, build_computation());
+ TF_ASSIGN_OR_ASSERT_OK(
+ result1,
+ client_->ExecuteAndTransfer(computation, /*arguments=*/{},
+ /*shape_with_output_layout=*/nullptr,
+ /*execution_profile=*/nullptr,
+ /*seed=*/42));
+ }
+ std::unique_ptr<Literal> result2;
+ std::unique_ptr<Literal> result3;
+ {
+ TF_ASSIGN_OR_ASSERT_OK(auto computation, build_computation());
+ TF_ASSIGN_OR_ASSERT_OK(
+ result2,
+ client_->ExecuteAndTransfer(computation, /*arguments=*/{},
+ /*shape_with_output_layout=*/nullptr,
+ /*execution_profile=*/nullptr,
+ /*seed=*/42));
+ TF_ASSIGN_OR_ASSERT_OK(
+ result3,
+ client_->ExecuteAndTransfer(computation, /*arguments=*/{},
+ /*shape_with_output_layout=*/nullptr,
+ /*execution_profile=*/nullptr,
+ /*seed=*/42));
+ }
+
+ std::unique_ptr<Literal> result4;
+ std::unique_ptr<Literal> result5;
+ std::unique_ptr<Literal> result6;
+ {
+ TF_ASSIGN_OR_ASSERT_OK(auto computation, build_computation());
+ TF_ASSIGN_OR_ASSERT_OK(
+ result4,
+ client_->ExecuteAndTransfer(computation, /*arguments=*/{},
+ /*shape_with_output_layout=*/nullptr,
+ /*execution_profile=*/nullptr,
+ /*seed=*/65));
+ TF_ASSIGN_OR_ASSERT_OK(
+ result5,
+ client_->ExecuteAndTransfer(computation, /*arguments=*/{},
+ /*shape_with_output_layout=*/nullptr,
+ /*execution_profile=*/nullptr));
+ TF_ASSIGN_OR_ASSERT_OK(
+ result6,
+ client_->ExecuteAndTransfer(computation, /*arguments=*/{},
+ /*shape_with_output_layout=*/nullptr,
+ /*execution_profile=*/nullptr));
+ }
+
+ LiteralTestUtil::ExpectEqual(*result1, *result2);
+ LiteralTestUtil::ExpectEqual(*result1, *result3);
+ LiteralTestUtil::ExpectNotEqual(*result1, *result4);
+ LiteralTestUtil::ExpectNotEqual(*result4, *result5);
+ LiteralTestUtil::ExpectNotEqual(*result5, *result6);
+}
+
+// Bernoulli random number generation tests
+XLA_TEST_F(PrngTest, HundredValuesB10p5) { BernoulliTest(0.5, {100}); }
+XLA_TEST_F(PrngTest, HundredValuesB10p1) { BernoulliTest(0.1, {100}); }
+
+XLA_TEST_F(PrngTest, TenValuesN01) {
+ ComputationBuilder builder(client_, TestName());
+ builder.RngNormal(builder.ConstantR0<float>(0), builder.ConstantR0<float>(1),
+ ShapeUtil::MakeShape(F32, {10}));
+
+ ExecuteAndTransferOrDie(&builder, /*arguments=*/{});
+ // TODO(b/25995601): Test that resultant values are reasonable
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/query_inferred_shape_test.cc b/tensorflow/compiler/xla/tests/query_inferred_shape_test.cc
new file mode 100644
index 0000000000..eb7e63705b
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/query_inferred_shape_test.cc
@@ -0,0 +1,61 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace {
+
+class QueryInferredShapeTest : public ClientLibraryTestBase {};
+
+TEST_F(QueryInferredShapeTest, OnePlusOneShape) {
+ ComputationBuilder builder(client_, "one_plus_one");
+ auto one = builder.ConstantR0<float>(1.0);
+ auto result = builder.Add(one, one);
+ StatusOr<std::unique_ptr<Shape>> shape_status = builder.GetShape(result);
+ ASSERT_IS_OK(shape_status.status());
+ auto shape = shape_status.ConsumeValueOrDie();
+ ASSERT_TRUE(ShapeUtil::Equal(*shape, ShapeUtil::MakeShape(F32, {})));
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/reduce_test.cc b/tensorflow/compiler/xla/tests/reduce_test.cc
new file mode 100644
index 0000000000..f3d8da5c8c
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/reduce_test.cc
@@ -0,0 +1,506 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Tests that multi-dimensional arrays can be reduced among various
+// user-provided dimensions.
+//
+// Note that comments for these tests are white-box in that they talk about the
+// default data layout.
+//
+// The test space for reductions is the Cartesian product of:
+//
+// <possible ranks> x
+// <possible layouts for chosen rank> x
+// <possible subsets of dimensions in chosen rank>
+
+#include <float.h>
+#include <math.h>
+#include <stdlib.h>
+#include <algorithm>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/compiler/xla/array2d.h"
+#include "tensorflow/compiler/xla/array4d.h"
+#include "tensorflow/compiler/xla/client/computation.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/global_data.h"
+#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/reference_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+class ReduceTest : public ClientLibraryTestBase {
+ protected:
+ ReduceTest() {
+    // Implementation note: laid out z >> y >> x by default.
+ // clang-format off
+ literal_2d_ = LiteralUtil::CreateR2<float>({
+ // x0 x1 x2
+ { 1.f, 2.f, 3.f}, // y0
+ { 4.f, 5.f, 6.f}, // y1
+ });
+ literal_3d_ = LiteralUtil::CreateR3Projected<float>({
+ // x0 x1 x2
+ { 1.f, 2.f, 3.f}, // y0
+ { 4.f, 5.f, 6.f}, // y1
+ }, 4);
+ // clang-format on
+ CHECK(ShapeUtil::Equal(
+ literal_3d_->shape(),
+ ShapeUtil::MakeShape(F32, {/*z=*/4, /*y=*/2, /*x=*/3})))
+ << literal_3d_->shape().ShortDebugString();
+ }
+
+ // Runs an R1 => R0 reduction test with the given number of elements.
+ void RunR1ToR0Test(int64 element_count) {
+ ComputationBuilder builder(client_, TestName());
+ Computation add_f32 = CreateScalarAddComputation(F32, &builder);
+ const Shape input_shape = ShapeUtil::MakeShape(F32, {element_count});
+ auto input = builder.Parameter(0, input_shape, "input");
+ auto zero = builder.ConstantR0<float>(0.0);
+ builder.Reduce(input, zero, add_f32, /*dimensions_to_reduce=*/{0});
+
+ std::vector<float> input_data(element_count);
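+    // Fill with small integers in [-2, 2] so the expected sum stays exact in
+    // float arithmetic even for large element counts.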
+ for (int64 i = 0; i < element_count; ++i) {
+ input_data[i] = rand_r(&seed_) % 3;
+ if (rand_r(&seed_) % 2 == 0) {
+ input_data[i] *= -1;
+ }
+ }
+ std::unique_ptr<Literal> input_literal =
+ LiteralUtil::CreateR1(AsSlice(input_data));
+ std::unique_ptr<GlobalData> input_global_data =
+ client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+
+ float expected = 0.0;
+ for (float item : input_data) {
+ expected += item;
+ }
+ ComputeAndCompareR0<float>(&builder, expected, {input_global_data.get()},
+ ErrorSpec(0.001));
+ }
+
+  // Runs an R2 => R0 reduction test with the given (rows, cols) dimensions.
+ void RunR2ToR0Test(int64 rows, int64 cols, int64 minor = 1, int64 major = 0) {
+ ComputationBuilder builder(client_, TestName());
+ Computation add_f32 = CreateScalarAddComputation(F32, &builder);
+ const Shape input_shape = ShapeUtil::MakeShape(F32, {rows, cols});
+ auto input = builder.Parameter(0, input_shape, "input");
+ auto zero = builder.ConstantR0<float>(0.0);
+ builder.Reduce(input, zero, add_f32, /*dimensions_to_reduce=*/{0, 1});
+
+ Array2D<float> input_data(rows, cols);
+ input_data.FillRandom(3.14f, 0.04);
+ std::unique_ptr<Literal> input_literal =
+ LiteralUtil::CreateR2FromArray2D(input_data);
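+    // Relayout according to (minor, major) so the same test body covers both
+    // row-major and column-major operands.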
+ input_literal = LiteralUtil::Relayout(
+ *input_literal, LayoutUtil::MakeLayout({minor, major}));
+ std::unique_ptr<GlobalData> input_global_data =
+ client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+
+ float expected = 0.0;
+ for (int64 rowno = 0; rowno < rows; ++rowno) {
+ for (int64 colno = 0; colno < cols; ++colno) {
+ expected += input_data(rowno, colno);
+ }
+ }
+ ComputeAndCompareR0<float>(&builder, expected, {input_global_data.get()},
+ ErrorSpec(0.01, 1e-4));
+ }
+
+  // Runs an R2 => R1 reduction test with the given (rows, cols) dimensions.
+ void RunR2ToR1Test(int64 rows, int64 cols, int64 minor = 1, int64 major = 0) {
+ ComputationBuilder builder(client_, TestName());
+ Computation add_f32 = CreateScalarAddComputation(F32, &builder);
+ const Shape input_shape = ShapeUtil::MakeShape(F32, {rows, cols});
+ auto input = builder.Parameter(0, input_shape, "input");
+ auto zero = builder.ConstantR0<float>(0.0);
+ builder.Reduce(input, zero, add_f32, /*dimensions_to_reduce=*/{0});
+
+ Array2D<float> input_data(rows, cols);
+ input_data.FillRandom(3.14f, 0.04);
+ std::unique_ptr<Literal> input_literal =
+ LiteralUtil::CreateR2FromArray2D(input_data);
+ input_literal = LiteralUtil::Relayout(
+ *input_literal, LayoutUtil::MakeLayout({minor, major}));
+ std::unique_ptr<GlobalData> input_global_data =
+ client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+
+ std::vector<float> expected;
+ for (int64 colno = 0; colno < cols; ++colno) {
+ float column_sum = 0;
+ for (int64 rowno = 0; rowno < rows; ++rowno) {
+ column_sum += input_data(rowno, colno);
+ }
+ expected.push_back(column_sum);
+ }
+ ComputeAndCompareR1<float>(&builder, expected, {input_global_data.get()},
+ ErrorSpec(0.01, 1e-4));
+ }
+
+ std::unique_ptr<Literal> literal_2d_;
+ std::unique_ptr<Literal> literal_3d_;
+ uint32 seed_ = 0xdeadbeef;
+};
+
+XLA_TEST_F(ReduceTest, ReduceR1_0_F32_To_R0) { RunR1ToR0Test(0); }
+XLA_TEST_F(ReduceTest, ReduceR1_1_F32_To_R0) { RunR1ToR0Test(1); }
+XLA_TEST_F(ReduceTest, ReduceR1_2_F32_To_R0) { RunR1ToR0Test(2); }
+XLA_TEST_F(ReduceTest, ReduceR1_16_F32_To_R0) { RunR1ToR0Test(16); }
+XLA_TEST_F(ReduceTest, ReduceR1_240_F32_To_R0) { RunR1ToR0Test(240); }
+XLA_TEST_F(ReduceTest, ReduceR1_128_F32_To_R0) { RunR1ToR0Test(128); }
+XLA_TEST_F(ReduceTest, ReduceR1_129_F32_To_R0) { RunR1ToR0Test(129); }
+XLA_TEST_F(ReduceTest, ReduceR1_256_F32_To_R0) { RunR1ToR0Test(256); }
+XLA_TEST_F(ReduceTest, ReduceR1_1024_F32_To_R0) { RunR1ToR0Test(1024); }
+XLA_TEST_F(ReduceTest, ReduceR1_2048_F32_To_R0) { RunR1ToR0Test(2048); }
+XLA_TEST_F(ReduceTest, ReduceR1_16K_F32_To_R0) { RunR1ToR0Test(16 * 1024); }
+XLA_TEST_F(ReduceTest, ReduceR1_16KP1_F32_To_R0) {
+ RunR1ToR0Test(16 * 1024 + 1);
+}
+
+XLA_TEST_F(ReduceTest, ReduceR2_0x0_To_R0) { RunR2ToR0Test(0, 0); }
+XLA_TEST_F(ReduceTest, ReduceR2_0x2_To_R0) { RunR2ToR0Test(0, 2); }
+XLA_TEST_F(ReduceTest, ReduceR2_1x1_To_R0) { RunR2ToR0Test(1, 1); }
+XLA_TEST_F(ReduceTest, ReduceR2_2x0_To_R0) { RunR2ToR0Test(2, 0); }
+XLA_TEST_F(ReduceTest, ReduceR2_2x2_To_R0) { RunR2ToR0Test(2, 2); }
+XLA_TEST_F(ReduceTest, ReduceR2_8x8_To_R0) { RunR2ToR0Test(8, 8); }
+XLA_TEST_F(ReduceTest, ReduceR2_9x9_To_R0) { RunR2ToR0Test(9, 9); }
+XLA_TEST_F(ReduceTest, ReduceR2_50x111_To_R0) { RunR2ToR0Test(50, 111); }
+XLA_TEST_F(ReduceTest, ReduceR2_111x50_To_R0) { RunR2ToR0Test(111, 50); }
+XLA_TEST_F(ReduceTest, ReduceR2_111x50_01_To_R0) {
+ RunR2ToR0Test(111, 50, 0, 1);
+}
+XLA_TEST_F(ReduceTest, ReduceR2_1024x1024_To_R0) { RunR2ToR0Test(1024, 1024); }
+XLA_TEST_F(ReduceTest, ReduceR2_1000x1500_To_R0) { RunR2ToR0Test(1000, 1500); }
+
+// Disabled due to b/33245142. Failed on 2016-11-30.
+// XLA_TEST_F(ReduceTest, ReduceR2_0x0_To_R1) { RunR2ToR1Test(0, 0); }
+XLA_TEST_F(ReduceTest, ReduceR2_0x2_To_R1) { RunR2ToR1Test(0, 2); }
+XLA_TEST_F(ReduceTest, ReduceR2_1x1_To_R1) { RunR2ToR1Test(1, 1); }
+// Disabled due to b/33245142. Failed on 2016-11-30.
+// XLA_TEST_F(ReduceTest, ReduceR2_2x0_To_R1) { RunR2ToR1Test(2, 0); }
+XLA_TEST_F(ReduceTest, ReduceR2_2x2_To_R1) { RunR2ToR1Test(2, 2); }
+XLA_TEST_F(ReduceTest, ReduceR2_8x8_To_R1) { RunR2ToR1Test(8, 8); }
+XLA_TEST_F(ReduceTest, ReduceR2_9x9_To_R1) { RunR2ToR1Test(9, 9); }
+XLA_TEST_F(ReduceTest, ReduceR2_50x111_To_R1) { RunR2ToR1Test(50, 111); }
+XLA_TEST_F(ReduceTest, ReduceR2_111x50_To_R1) { RunR2ToR1Test(111, 50); }
+XLA_TEST_F(ReduceTest, ReduceR2_111x50_01_To_R1) {
+ RunR2ToR1Test(111, 50, 0, 1);
+}
+XLA_TEST_F(ReduceTest, ReduceR2_1024x1024_To_R1) { RunR2ToR1Test(1024, 1024); }
+XLA_TEST_F(ReduceTest, ReduceR2_1000x1500_To_R1) { RunR2ToR1Test(1000, 1500); }
+
+XLA_TEST_F(ReduceTest, ReduceElementwiseR2_111x50_To_R1) {
+ const int64 rows = 111, cols = 50;
+
+ ComputationBuilder builder(client_, TestName());
+ Computation add_f32 = CreateScalarAddComputation(F32, &builder);
+ const Shape input_shape = ShapeUtil::MakeShape(F32, {rows, cols});
+ auto input = builder.Parameter(0, input_shape, "input");
+ auto zero = builder.ConstantR0<float>(0.0);
+ auto log_ = builder.Log(input);
+ builder.Reduce(log_, zero, add_f32, /*dimensions_to_reduce=*/{0});
+
+ Array2D<float> input_data(rows, cols);
+ input_data.FillRandom(3.14f, 0.04);
+ std::unique_ptr<Literal> input_literal =
+ LiteralUtil::CreateR2FromArray2D(input_data);
+ input_literal =
+ LiteralUtil::Relayout(*input_literal, LayoutUtil::MakeLayout({0, 1}));
+ std::unique_ptr<GlobalData> input_global_data =
+ client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+
+ std::vector<float> expected;
+ for (int64 colno = 0; colno < cols; ++colno) {
+ float column_sum = 0;
+ for (int64 rowno = 0; rowno < rows; ++rowno) {
+ column_sum += log(input_data(rowno, colno));
+ }
+ expected.push_back(column_sum);
+ }
+ ComputeAndCompareR1<float>(&builder, expected, {input_global_data.get()},
+ ErrorSpec(0.01, 1e-4));
+}
+
+struct BoundsLayout {
+ std::vector<int64> bounds;
+ std::vector<int64> layout;
+ std::vector<int64> reduce_dims;
+};
+
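+// Formats a BoundsLayout so gtest prints readable parameterized test names,
+// e.g. "R3ToR2_4x8x128_210_Reduce0".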
+void PrintTo(const BoundsLayout& spec, std::ostream* os) {
+ *os << tensorflow::strings::Printf(
+ "R%luToR%lu%s_%s_Reduce%s", spec.bounds.size(),
+ spec.bounds.size() - spec.reduce_dims.size(),
+ tensorflow::str_util::Join(spec.bounds, "x").c_str(),
+ tensorflow::str_util::Join(spec.layout, "").c_str(),
+ tensorflow::str_util::Join(spec.reduce_dims, "").c_str());
+}
+
+// Add-reduces a broadcasted scalar matrix among dimensions 1 and 0.
+XLA_TEST_F(ReduceTest, AddReduce2DScalarToR0) {
+ ComputationBuilder builder(client_, TestName());
+ auto add = CreateScalarAddComputation(F32, &builder);
+ auto scalar = builder.ConstantR0<float>(42.0);
+  auto broadcasted = builder.Broadcast(scalar, {500, 500});
+  builder.Reduce(broadcasted, builder.ConstantR0<float>(0.0f), add, {0, 1});
+
+ float expected = 42.0f * static_cast<float>(500 * 500);
+ ComputeAndCompareR0<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+// Max-reduces a broadcasted scalar matrix among dimensions 1 and 0.
+XLA_TEST_F(ReduceTest, MaxReduce2DScalarToR0) {
+ ComputationBuilder builder(client_, TestName());
+ auto max = CreateScalarMaxComputation(F32, &builder);
+ auto scalar = builder.ConstantR0<float>(42.0);
+  auto broadcasted = builder.Broadcast(scalar, {500, 500});
+  builder.Reduce(broadcasted, builder.ConstantR0<float>(0.0f), max, {0, 1});
+
+ float expected = 42.0f;
+ ComputeAndCompareR0<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+// Max-reduces a matrix among dimensions 1 and 0.
+XLA_TEST_F(ReduceTest, MaxReduce2DToR0) {
+ ComputationBuilder builder(client_, TestName());
+ auto max = CreateScalarMaxComputation(F32, &builder);
+ Array2D<float> input(300, 250);
+ input.FillRandom(214.0f);
+ auto input_literal = LiteralUtil::CreateR2FromArray2D(input);
+ builder.Reduce(builder.ConstantLiteral(*input_literal),
+ builder.ConstantR0<float>(FLT_MIN), max, {0, 1});
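+  // Note: FLT_MIN is the smallest positive float, not the most negative; the
+  // reference computation below uses the same init value, so both agree.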
+ auto input_max = FLT_MIN;
+ input.Each(
+ [&](int64, int64, float* v) { input_max = std::max(input_max, *v); });
+ ComputeAndCompareR0<float>(&builder, input_max, {}, ErrorSpec(0.0001));
+}
+
+// Min-reduces a matrix among dimensions 1 and 0.
+XLA_TEST_F(ReduceTest, MinReduce2DToR0) {
+ ComputationBuilder builder(client_, TestName());
+ auto min = CreateScalarMinComputation(F32, &builder);
+ Array2D<float> input(150, 130);
+ input.FillRandom(214.0f);
+ auto input_literal = LiteralUtil::CreateR2FromArray2D(input);
+ builder.Reduce(builder.ConstantLiteral(*input_literal),
+ builder.ConstantR0<float>(FLT_MAX), min, {0, 1});
+
+ auto input_min = FLT_MAX;
+ input.Each(
+ [&](int64, int64, float* v) { input_min = std::min(input_min, *v); });
+ ComputeAndCompareR0<float>(&builder, input_min, {}, ErrorSpec(0.0001));
+}
+
+// Reduces a matrix among dimension 1.
+XLA_TEST_F(ReduceTest, Reduce2DAmong1) {
+ ComputationBuilder builder(client_, TestName());
+ auto m = builder.ConstantLiteral(*literal_2d_);
+ auto add = CreateScalarAddComputation(F32, &builder);
+ builder.Reduce(m, builder.ConstantR0<float>(0.0f), add, {1});
+
+ std::vector<float> expected = {6.f, 15.f};
+ ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(ReduceTest, Reduce2DAmong0and1) {
+ // Reduce a matrix among dimensions 0 and 1 (sum it up to a scalar).
+ ComputationBuilder builder(client_, TestName());
+ auto m = builder.ConstantLiteral(*literal_2d_);
+ auto add = CreateScalarAddComputation(F32, &builder);
+ builder.Reduce(m, builder.ConstantR0<float>(0.0f), add, {0, 1});
+
+ ComputeAndCompareR0<float>(&builder, 21.0f, {}, ErrorSpec(0.0001, 1e-4));
+}
+
+// Tests 2D matrix ReduceToRow operation.
+XLA_TEST_F(ReduceTest, Reduce2DAmongY) {
+ ComputationBuilder builder(client_, "reduce_among_y");
+ auto m = builder.ConstantLiteral(*literal_2d_);
+ auto add = CreateScalarAddComputation(F32, &builder);
+ builder.Reduce(m, builder.ConstantR0<float>(0.0f), add, {0});
+
+ std::vector<float> expected = {5.f, 7.f, 9.f};
+ ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(ReduceTest, ReduceR3AmongDims_1_2) {
+ ComputationBuilder builder(client_, TestName());
+ auto m = builder.ConstantLiteral(*literal_3d_);
+ auto add = CreateScalarAddComputation(F32, &builder);
+ builder.Reduce(m, builder.ConstantR0<float>(0.0f), add, {1, 2});
+
+ std::vector<float> expected = {21.f, 21.f, 21.f, 21.f};
+ ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(ReduceTest, ReduceR3AmongDims_0_1) {
+ ComputationBuilder builder(client_, TestName());
+ auto m = builder.ConstantLiteral(*literal_3d_);
+ auto add = CreateScalarAddComputation(F32, &builder);
+ builder.Reduce(m, builder.ConstantR0<float>(0.0f), add, {0, 1});
+
+ std::vector<float> expected = {20.f, 28.f, 36.f};
+ ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(ReduceTest, ReduceR3ToR0) {
+ ComputationBuilder builder(client_, TestName());
+ auto m = builder.ConstantLiteral(*literal_3d_);
+ auto add = CreateScalarAddComputation(F32, &builder);
+ builder.Reduce(m, builder.ConstantR0<float>(0.0f), add, {0, 1, 2});
+
+  float expected = 21.0f * 4.0f;
+ ComputeAndCompareR0<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(ReduceTest, ReduceR3AmongDim0) {
+ ComputationBuilder builder(client_, TestName());
+ auto m = builder.ConstantLiteral(*literal_3d_);
+ auto add = CreateScalarAddComputation(F32, &builder);
+ builder.Reduce(m, builder.ConstantR0<float>(0.0f), add, {0});
+
+ // clang-format off
+ Array2D<float> expected({
+ {4.f, 8.f, 12.f},
+ {16.f, 20.f, 24.f},
+ });
+ // clang-format on
+ ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(ReduceTest, ReduceR3AmongDim1) {
+ ComputationBuilder builder(client_, TestName());
+ auto m = builder.ConstantLiteral(*literal_3d_);
+ auto add = CreateScalarAddComputation(F32, &builder);
+ builder.Reduce(m, builder.ConstantR0<float>(0.0f), add, {1});
+
+ // clang-format off
+ Array2D<float> expected({
+ {5.f, 7.f, 9.f},
+ {5.f, 7.f, 9.f},
+ {5.f, 7.f, 9.f},
+ {5.f, 7.f, 9.f},
+ });
+ // clang-format on
+ ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(ReduceTest, ReduceR3AmongDim2) {
+ ComputationBuilder builder(client_, TestName());
+ auto m = builder.ConstantLiteral(*literal_3d_);
+ auto add = CreateScalarAddComputation(F32, &builder);
+ builder.Reduce(m, builder.ConstantR0<float>(0.0f), add, {2});
+
+ // clang-format off
+ Array2D<float> expected({
+ {6.f, 15.f},
+ {6.f, 15.f},
+ {6.f, 15.f},
+ {6.f, 15.f},
+ });
+ // clang-format on
+ ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+class ReduceR3ToR2Test : public ReduceTest,
+ public ::testing::WithParamInterface<BoundsLayout> {};
+
+XLA_TEST_P(ReduceR3ToR2Test, ReduceR3ToR2) {
+ ComputationBuilder builder(client_, TestName());
+ const auto& bounds = GetParam().bounds;
+ Array3D<float> input_array(bounds[0], bounds[1], bounds[2]);
+ input_array.FillRandom(3.14f, 0.05);
+
+ auto input_literal = LiteralUtil::CreateR3FromArray3D(input_array);
+ input_literal = LiteralUtil::Relayout(
+ *input_literal, LayoutUtil::MakeLayout(GetParam().layout));
+ std::unique_ptr<GlobalData> input_data =
+ client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+
+ auto input_activations =
+ builder.Parameter(0, input_literal->shape(), "input");
+ Computation add = CreateScalarAddComputation(F32, &builder);
+ auto sum = builder.Reduce(input_activations, builder.ConstantR0<float>(0.0f),
+ add, GetParam().reduce_dims);
+
+ auto expected =
+ ReferenceUtil::Reduce3DTo2D(input_array, 0.0f, GetParam().reduce_dims,
+ [](float a, float b) { return a + b; });
+
+ ComputeAndCompareR2<float>(&builder, *expected, {input_data.get()},
+ ErrorSpec(1e-3, 1e-3));
+}
+
+INSTANTIATE_TEST_CASE_P(
+ ReduceR3ToR2Test_Instantiation, ReduceR3ToR2Test,
+ // Specifies (shape, layout, reduction dimensions).
+ ::testing::Values(BoundsLayout{{4, 8, 128}, {2, 1, 0}, {0}},
+ BoundsLayout{{4, 8, 128}, {2, 1, 0}, {1}},
+ BoundsLayout{{4, 8, 128}, {2, 1, 0}, {2}},
+ // These should be simplified into a reshape.
+ BoundsLayout{{1, 21, 43}, {2, 1, 0}, {0}},
+ BoundsLayout{{1, 1, 1}, {2, 1, 0}, {0}},
+ BoundsLayout{{1, 1, 1}, {2, 1, 0}, {1}},
+ BoundsLayout{{1, 1, 1}, {2, 1, 0}, {2}},
+ BoundsLayout{{8, 16, 24}, {0, 1, 2}, {0}},
+ BoundsLayout{{8, 16, 24}, {0, 1, 2}, {1}},
+ BoundsLayout{{8, 16, 24}, {0, 1, 2}, {2}},
+ BoundsLayout{{5, 10, 250}, {2, 1, 0}, {0}},
+ BoundsLayout{{5, 10, 250}, {2, 1, 0}, {1}},
+ BoundsLayout{{5, 10, 250}, {2, 1, 0}, {2}},
+ BoundsLayout{{8, 16, 256}, {2, 1, 0}, {0}},
+ BoundsLayout{{8, 16, 256}, {2, 1, 0}, {1}},
+ BoundsLayout{{8, 16, 256}, {2, 1, 0}, {2}},
+ BoundsLayout{{2, 300, 784}, {2, 1, 0}, {2}},
+ BoundsLayout{{2, 300, 784}, {2, 1, 0}, {1}},
+ BoundsLayout{{2, 300, 784}, {2, 1, 0}, {0}}));
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc
new file mode 100644
index 0000000000..f48c14dfc6
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc
@@ -0,0 +1,445 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Tests the reduce-window XLA operation.
+
+#include <limits>
+#include <memory>
+
+#include "tensorflow/compiler/xla/array2d.h"
+#include "tensorflow/compiler/xla/array3d.h"
+#include "tensorflow/compiler/xla/array4d.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/client/padding.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/reference_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+class ReduceWindowTest : public ClientLibraryTestBase {
+ public:
+ ReduceWindowTest() : builder_(client_, TestName()) {}
+
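+  // Convenience wrappers: each appends a reduce-window op to builder_ with
+  // the matching scalar computation (add/max/min) and its identity value.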
+ void ReduceWindowAdd(ComputationDataHandle input,
+ tensorflow::gtl::ArraySlice<int64> window_dimensions,
+ tensorflow::gtl::ArraySlice<int64> window_strides,
+ Padding padding) {
+ builder_.ReduceWindow(input, builder_.ConstantR0<float>(0.0f),
+ CreateScalarAddComputation(F32, &builder_),
+ window_dimensions, window_strides, padding);
+ }
+
+ void ReduceWindowMax(ComputationDataHandle input,
+ tensorflow::gtl::ArraySlice<int64> window_dimensions,
+ tensorflow::gtl::ArraySlice<int64> window_strides,
+ Padding padding) {
+ builder_.ReduceWindow(
+ input, builder_.ConstantLiteral(LiteralUtil::MinValue(F32)),
+ CreateScalarMax(), window_dimensions, window_strides, padding);
+ }
+
+ void ReduceWindowMin(ComputationDataHandle input,
+ tensorflow::gtl::ArraySlice<int64> window_dimensions,
+ tensorflow::gtl::ArraySlice<int64> window_strides,
+ Padding padding) {
+ builder_.ReduceWindow(input,
+ builder_.ConstantLiteral(LiteralUtil::MaxValue(F32)),
+ CreateScalarMinComputation(F32, &builder_),
+ window_dimensions, window_strides, padding);
+ }
+
+ ComputationBuilder builder_;
+};
+
+XLA_TEST_F(ReduceWindowTest, ZeroElementSmall) {
+ Array4D<float> input_array(1, 0, 2, 1);
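+  // Dimension 1 has zero extent, so the output is also zero-sized; this
+  // checks that reduce-window handles empty inputs gracefully.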
+
+ const auto input = builder_.ConstantR4FromArray4D<float>(input_array);
+ Padding padding = Padding::kSame;
+ ReduceWindowAdd(input, {1, 1, 2, 1}, {1, 1, 1, 1}, padding);
+
+ auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 2, 1},
+ {1, 1, 1, 1}, padding);
+
+ ComputeAndCompareR4<float>(&builder_, *res, {}, ErrorSpec(1e-3, 1e-3));
+}
+
+TEST_F(ReduceWindowTest, NonSquareSmall) {
+ Array4D<float> input_array(1, 2, 2, 1);
+ input_array.FillRandom(2.f);
+
+ const auto input = builder_.ConstantR4FromArray4D<float>(input_array);
+ Padding padding = Padding::kSame;
+ ReduceWindowAdd(input, {1, 1, 2, 1}, {1, 1, 1, 1}, padding);
+
+ auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 2, 1},
+ {1, 1, 1, 1}, padding);
+
+ ComputeAndCompareR4<float>(&builder_, *res, {}, ErrorSpec(1e-3, 1e-3));
+}
+
+TEST_F(ReduceWindowTest, MiddleDimsSmall) {
+ Array4D<float> input_array(1, 3, 3, 1);
+ input_array.FillRandom(2.f);
+
+ const auto input = builder_.ConstantR4FromArray4D<float>(input_array);
+ Padding padding = Padding::kSame;
+ ReduceWindowAdd(input, {1, 1, 1, 1}, {1, 2, 2, 1}, padding);
+
+ auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 1, 1},
+ {1, 2, 2, 1}, padding);
+
+ ComputeAndCompareR4<float>(&builder_, *res, {}, ErrorSpec(1e-3, 1e-3));
+}
+
+TEST_F(ReduceWindowTest, Along2ndMinorDim) {
+ Array4D<float> input_array(3, 6, 7, 32);
+ input_array.FillRandom(2.f);
+
+ // The parameters of this reduction mimic feature norm (e.g. LRN).
+ int lrn_diameter = 7; // diameter = 2*radius + 1 --> must be odd
+ const auto input = builder_.ConstantR4FromArray4D<float>(input_array);
+ Padding padding = Padding::kSame;
+ ReduceWindowAdd(input, {1, 1, lrn_diameter, 1}, {1, 1, 1, 1}, padding);
+
+ auto res = ReferenceUtil::ReduceWindow4DAdd(
+ input_array, 0.0f, {1, 1, lrn_diameter, 1}, {1, 1, 1, 1}, padding);
+
+ ComputeAndCompareR4<float>(&builder_, *res, {}, ErrorSpec(1e-3, 1e-3));
+}
+
+TEST_F(ReduceWindowTest, AmongMajor2DimsMediumSize) {
+ Array4D<float> input_array(9, 12, 4, 89);
+ input_array.FillRandom(2.0f);
+
+ int win_len = 3;
+ int win_stride = 2;
+
+ const auto input_data_handle =
+ builder_.ConstantR4FromArray4D<float>(input_array);
+
+ Padding padding = Padding::kSame;
+ // Reduce only along the x and y dimensions, according to the win_len.
+ ReduceWindowAdd(input_data_handle, {win_len, win_len, 1, 1},
+ {win_stride, win_stride, 1, 1}, padding);
+
+ auto result = ReferenceUtil::ReduceWindow4DAdd(
+ input_array, 0.0f, {win_len, win_len, 1, 1},
+ {win_stride, win_stride, 1, 1}, padding);
+
+ ComputeAndCompareR4<float>(&builder_, *result, {}, ErrorSpec(1e-3, 1e-3));
+}
+
+// TODO(b/32173947): Test support for arbitrary-sized padding.
+TEST_F(ReduceWindowTest, DISABLED_AmongMajor2DimsMediumSizeLargePadding) {
+ Array4D<float> input_array(9, 12, 4, 89); // simulate Dim0IsMinor layout
+ input_array.FillRandom(2.0f);
+
+ int64 rank = 4;
+ int win_len = 3;
+ int win_stride = 2;
+
+ const auto input_data_handle =
+ builder_.ConstantR4FromArray4D<float>(input_array);
+
+ Padding padding = Padding::kSame;
+ // Reduce only along the x and y dimensions, according to the win_len.
+ // Create padding vector with large padding values in the reduction dims.
+ std::vector<std::pair<int64, int64>> low_high_padding;
+ low_high_padding.resize(rank, {4, 4});
+
+ builder_.ReduceWindowWithGeneralPadding(
+ input_data_handle, builder_.ConstantR0<float>(0.0f),
+ CreateScalarAddComputation(F32, &builder_), {win_len, win_len, 1, 1},
+ {win_stride, win_stride, 1, 1}, low_high_padding);
+
+ auto result = ReferenceUtil::ReduceWindow4DAdd(
+ input_array, 0.0f, {win_len, win_len, 1, 1},
+ {win_stride, win_stride, 1, 1}, padding);
+
+ ComputeAndCompareR4<float>(&builder_, *result, {}, ErrorSpec(1e-3, 1e-3));
+}
+
+// TODO(b/31809540): Implement minor dim reduction to reduce num of reshapes.
+TEST_F(ReduceWindowTest, ReduceR4AmongXYMinorSmall) {
+ Array4D<float> input_array(2, 2, 4, 16);
+
+ Array2D<float> yx({{0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f,
+ 11.f, 12.f, 13.f, 14.f, 15.f},
+ {16.f, 17.f, 18.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f,
+ 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 31.f},
+ {32.f, 33.f, 34.f, 35.f, 36.f, 37.f, 38.f, 39.f, 40.f,
+ 41.f, 42.f, 43.f, 44.f, 45.f, 46.f, 47.f},
+ {48.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 55.f, 56.f,
+ 57.f, 58.f, 59.f, 60.f, 61.f, 62.f, 63.f}});
+ input_array.FillWithYX(yx);
+
+ int win_len = 2;
+ int win_stride = 2;
+ const auto input = builder_.ConstantR4FromArray4D<float>(input_array);
+ Padding padding = Padding::kValid;
+ ReduceWindowAdd(input, {1, 1, win_len, win_len},
+ {1, 1, win_stride, win_stride}, padding);
+
+ auto res = ReferenceUtil::ReduceWindow4DAdd(
+ input_array, 0.0f, {1, 1, win_len, win_len},
+ {1, 1, win_stride, win_stride}, padding);
+ ComputeAndCompareR4<float>(&builder_, *res, {}, ErrorSpec(1e-3, 1e-3));
+}
+
+// TODO(b/31809540): Implement minor dim reduction to reduce num of reshapes.
+TEST_F(ReduceWindowTest, ReduceR4AmongXYMinorSmallOverlapped) {
+ constexpr int64 p = 2;
+ constexpr int64 z = 2;
+ constexpr int64 y = 4;
+ constexpr int64 x = 16;
+ Array4D<float> input_array(p, z, y, x);
+
+ Array2D<float> yx({{0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f,
+ 11.f, 12.f, 13.f, 14.f, 15.f},
+ {16.f, 17.f, 18.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f,
+ 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 31.f},
+ {32.f, 33.f, 34.f, 35.f, 36.f, 37.f, 38.f, 39.f, 40.f,
+ 41.f, 42.f, 43.f, 44.f, 45.f, 46.f, 47.f},
+ {48.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 55.f, 56.f,
+ 57.f, 58.f, 59.f, 60.f, 61.f, 62.f, 63.f}});
+ input_array.FillWithYX(yx);
+
+ int win_len = 4;
+ int win_stride = 2;
+ const auto input = builder_.ConstantR4FromArray4D<float>(input_array);
+ ReduceWindowAdd(input, {1, 1, win_len, win_len},
+ {1, 1, win_stride, win_stride}, Padding::kValid);
+
+ // Expected result
+ Array2D<float> yx_result({{408.f, 440.f, 472.f, 504.f, 536.f, 568.f, 600.f}});
+ Array4D<float> expected(p, z, 1, 7);
+ expected.FillWithYX(yx_result);
+ ComputeAndCompareR4<float>(&builder_, expected, {}, ErrorSpec(1e-3, 1e-3));
+}
+
+TEST_F(ReduceWindowTest, MaxTrivial) {
+ const auto input = builder_.ConstantR1<float>({42});
+ ReduceWindowMax(input, {1}, {1}, Padding::kValid);
+ ComputeAndCompareR1<float>(&builder_, {42}, {}, ErrorSpec(0.0001));
+}
+
+TEST_F(ReduceWindowTest, Add3In3) {
+ const auto input = builder_.ConstantR1<float>({20, 100, 3});
+ ReduceWindowAdd(input, {3}, {1}, Padding::kValid);
+ ComputeAndCompareR1<float>(&builder_, {123}, {}, ErrorSpec(0.0001));
+}
+
+TEST_F(ReduceWindowTest, Add4In16Stride4) {
+ const auto input = builder_.ConstantR1<float>(
+ {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
+ ReduceWindowAdd(input, {4}, {4}, Padding::kValid);
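+  // Non-overlapping windows {1..4}, {5..8}, {9..12}, {13..16} sum to
+  // 10, 26, 42, and 58 respectively.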
+ ComputeAndCompareR1<float>(&builder_, {10, 26, 42, 58}, {},
+ ErrorSpec(0.0001));
+}
+
+TEST_F(ReduceWindowTest, DISABLED_ON_CPU(DISABLED_ON_GPU(Min3In5Stride2))) {
+ const auto input = builder_.ConstantR1<float>({10000, 1000, 100, 10, 1});
+ ReduceWindowMin(input, {3}, {2}, Padding::kValid);
+ ComputeAndCompareR1<float>(&builder_, {100, 1}, {}, ErrorSpec(0.0001));
+}
+
+TEST_F(ReduceWindowTest, Max3In3) {
+ const auto input = builder_.ConstantR1<float>({20, 100, 3});
+ ReduceWindowMax(input, {3}, {1}, Padding::kValid);
+ ComputeAndCompareR1<float>(&builder_, {100}, {}, ErrorSpec(0.0001));
+}
+
+TEST_F(ReduceWindowTest, Add2In3) {
+ const auto input = builder_.ConstantR1<float>({100, 10, 1});
+ ReduceWindowAdd(input, {2}, {1}, Padding::kValid);
+ ComputeAndCompareR1<float>(&builder_, {110, 11}, {}, ErrorSpec(0.0001));
+}
+
+TEST_F(ReduceWindowTest, Add3In5Stride2) {
+ const auto input = builder_.ConstantR1<float>({10000, 1000, 100, 10, 1});
+ ReduceWindowAdd(input, {3}, {2}, Padding::kValid);
+ ComputeAndCompareR1<float>(&builder_, {11100, 111}, {}, ErrorSpec(0.0001));
+}
+
+TEST_F(ReduceWindowTest, Max4In16Stride4) {
+ const auto input = builder_.ConstantR1<float>(
+ {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
+ ReduceWindowMax(input, {4}, {4}, Padding::kValid);
+ ComputeAndCompareR1<float>(&builder_, {4, 8, 12, 16}, {}, ErrorSpec(0.0001));
+}
+
+TEST_F(ReduceWindowTest, Max4In16Stride3) {
+ const auto input = builder_.ConstantR1<float>(
+ {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
+ ReduceWindowMax(input, {4}, {3}, Padding::kValid);
+ ComputeAndCompareR1<float>(&builder_, {4, 7, 10, 13, 16}, {},
+ ErrorSpec(0.0001));
+}
+
+TEST_F(ReduceWindowTest, Max4In16Stride8) {
+ const auto input = builder_.ConstantR1<float>(
+ {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
+ ReduceWindowMax(input, {4}, {8}, Padding::kValid);
+ ComputeAndCompareR1<float>(&builder_, {4, 12}, {}, ErrorSpec(0.0001));
+}
+
+TEST_F(ReduceWindowTest, Max3In5Stride2) {
+ const auto input = builder_.ConstantR1<float>({10000, 1000, 100, 10, 1});
+ ReduceWindowMax(input, {3}, {2}, Padding::kValid);
+ ComputeAndCompareR1<float>(&builder_, {10000, 100}, {}, ErrorSpec(0.0001));
+}
+
+TEST_F(ReduceWindowTest, Max3In5Stride1) {
+ const auto input = builder_.ConstantR1<float>({10000, 1000, 100, 10, 101});
+ ReduceWindowMax(input, {3}, {1}, Padding::kValid);
+ ComputeAndCompareR1<float>(&builder_, {10000, 1000, 101}, {},
+ ErrorSpec(0.0001));
+}
+
+TEST_F(ReduceWindowTest, Add3In4Stride2) {
+ const auto input = builder_.ConstantR1<float>({1000, 100, 10, 1});
+ ReduceWindowAdd(input, {3}, {2}, Padding::kValid);
+ ComputeAndCompareR1<float>(&builder_, {1110}, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(ReduceWindowTest, Add2In3SamePad) {
+ const auto input = builder_.ConstantR1<float>({100, 10, 1});
+ ReduceWindowAdd(input, {2}, {1}, Padding::kSame);
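+  // kSame keeps the output length at 3 by zero-padding on the right:
+  // windows are {100,10}, {10,1}, {1,0}.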
+ ComputeAndCompareR1<float>(&builder_, {110, 11, 1}, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(ReduceWindowTest, Add3In3SamePad) {
+ const auto input = builder_.ConstantR1<float>({100, 10, 1});
+ ReduceWindowAdd(input, {3}, {1}, Padding::kSame);
+ ComputeAndCompareR1<float>(&builder_, {110, 111, 11}, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(ReduceWindowTest, Add3In3Stride2SamePad) {
+ const auto input = builder_.ConstantR1<float>({100, 10, 1});
+ ReduceWindowAdd(input, {3}, {2}, Padding::kSame);
+ ComputeAndCompareR1<float>(&builder_, {110, 11}, {}, ErrorSpec(0.0001));
+}
+
+TEST_F(ReduceWindowTest, Add2x2In2x2Overlapped) {
+ Array2D<float> input_array({{1.2f, -2.5f, 0.9f, 1.0f},
+ {3.7f, 0.2f, -1.0f, -0.2f},
+ {-0.4f, 2.7f, 1.1f, 2.2f},
+ {0.6f, 1.7f, 1.4f, -0.2f}});
+ auto input = builder_.ConstantR2FromArray2D<float>(input_array);
+ ReduceWindowAdd(input, {2, 2}, {1, 1}, Padding::kValid);
+ Array2D<float> expected(
+ {{2.6f, -2.4f, 0.7f}, {6.2f, 3.0f, 2.1f}, {4.6f, 6.9f, 4.5f}});
+ ComputeAndCompareR2<float>(&builder_, expected, {}, ErrorSpec(0.0001));
+}
+
+TEST_F(ReduceWindowTest, Add2x2In2x2Disjoint) {
+ Array2D<float> input_array({{1.2f, -2.5f, 0.9f, 1.0f},
+ {3.7f, 0.2f, -1.0f, -0.2f},
+ {-0.4f, 2.7f, 1.1f, 2.2f},
+ {0.6f, 1.7f, 1.4f, -0.2f}});
+ auto input = builder_.ConstantR2FromArray2D<float>(input_array);
+ ReduceWindowAdd(input, {2, 2}, {2, 2}, Padding::kValid);
+ Array2D<float> expected({
+ {2.6f, 0.7f}, {4.6f, 4.5f},
+ });
+ ComputeAndCompareR2<float>(&builder_, expected, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(ReduceWindowTest, Add1x1x2In2x1x2) {
+ Array3D<float> input_array(2, 1, 2);
+ input_array(0, 0, 0) = 1000;
+ input_array(0, 0, 1) = 100;
+ input_array(1, 0, 0) = 10;
+ input_array(1, 0, 1) = 1;
+ auto input = builder_.ConstantR3FromArray3D<float>(input_array);
+
+ ReduceWindowAdd(input, {1, 1, 2}, {1, 1, 1}, Padding::kValid);
+
+ Array3D<float> expected(2, 1, 1);
+ expected(0, 0, 0) = 1100;
+ expected(1, 0, 0) = 11;
+ ComputeAndCompareR3<float>(&builder_, expected, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(ReduceWindowTest, Add1x1x2In2x1x3Stride1x1x2) {
+ Array3D<float> input_array(2, 1, 3);
+ input_array(0, 0, 0) = 100;
+ input_array(0, 0, 1) = 10;
+ input_array(0, 0, 2) = 1;
+ input_array(1, 0, 0) = 500;
+ input_array(1, 0, 1) = 50;
+ input_array(1, 0, 2) = 5;
+ auto input = builder_.ConstantR3FromArray3D<float>(input_array);
+
+ ReduceWindowAdd(input, {1, 1, 2}, {1, 1, 2}, Padding::kValid);
+
+ Array3D<float> expected(2, 1, 1);
+ expected(0, 0, 0) = 110;
+ expected(1, 0, 0) = 550;
+ ComputeAndCompareR3<float>(&builder_, expected, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(ReduceWindowTest, Add1x1x2In2x1x3SamePad) {
+ Array3D<float> input_array(2, 1, 3);
+ input_array(0, 0, 0) = 100;
+ input_array(0, 0, 1) = 10;
+ input_array(0, 0, 2) = 1;
+ input_array(1, 0, 0) = 500;
+ input_array(1, 0, 1) = 50;
+ input_array(1, 0, 2) = 5;
+ auto input = builder_.ConstantR3FromArray3D<float>(input_array);
+
+ ReduceWindowAdd(input, {1, 1, 2}, {1, 1, 1}, Padding::kSame);
+
+ Array3D<float> expected(2, 1, 3);
+ expected(0, 0, 0) = 110;
+ expected(0, 0, 1) = 11;
+ expected(0, 0, 2) = 1;
+ expected(1, 0, 0) = 550;
+ expected(1, 0, 1) = 55;
+ expected(1, 0, 2) = 5;
+ ComputeAndCompareR3<float>(&builder_, expected, {}, ErrorSpec(0.0001));
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/replay_test.cc b/tensorflow/compiler/xla/tests/replay_test.cc
new file mode 100644
index 0000000000..802087b508
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/replay_test.cc
@@ -0,0 +1,168 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+
+#include "tensorflow/compiler/xla/client/computation.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/global_data.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/protobuf_util.h"
+#include "tensorflow/compiler/xla/service/session.pb.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+class ReplayTest : public ClientLibraryTestBase {};
+
+TEST_F(ReplayTest, TwoPlusTwoReplay) {
+ // Make 2+2 computation.
+ ComputationBuilder builder(client_, TestName());
+ auto two = builder.ConstantR0<int32>(2);
+ builder.Add(two, two);
+ Computation computation = builder.Build().ConsumeValueOrDie();
+
+ // Serialize it out.
+ std::unique_ptr<SessionModule> module =
+ computation.Snapshot().ConsumeValueOrDie();
+
+ // Replay it.
+ Computation replayed = client_->LoadSnapshot(*module).ConsumeValueOrDie();
+
+ // Check signature is the same.
+ std::unique_ptr<ProgramShape> original_shape =
+ client_->GetComputationShape(computation).ConsumeValueOrDie();
+ std::unique_ptr<ProgramShape> replayed_shape =
+ client_->GetComputationShape(replayed).ConsumeValueOrDie();
+ ASSERT_TRUE(protobuf_util::ProtobufEquals(*original_shape, *replayed_shape));
+
+ // Run it.
+ std::unique_ptr<Literal> literal =
+ client_->ExecuteAndTransfer(replayed, /*arguments=*/{})
+ .ConsumeValueOrDie();
+
+ // Expect 4.
+ LiteralTestUtil::ExpectR0Equal<int32>(4, *literal);
+}
+
+XLA_TEST_F(ReplayTest, XPlusYReplayWithParameters) {
+ // Make computation.
+ ComputationBuilder builder(client_, TestName());
+ auto x = builder.Parameter(0, ShapeUtil::MakeShape(S32, {}), "x");
+ auto y = builder.Parameter(1, ShapeUtil::MakeShape(S32, {}), "y");
+ builder.Add(x, y);
+ Computation computation = builder.Build().ConsumeValueOrDie();
+
+ // Serialize it out.
+ std::unique_ptr<SessionModule> module =
+ computation.Snapshot().ConsumeValueOrDie();
+
+ // Replay it.
+ Computation replayed = client_->LoadSnapshot(*module).ConsumeValueOrDie();
+
+ // Check signature is the same.
+ std::unique_ptr<ProgramShape> original_shape =
+ client_->GetComputationShape(computation).ConsumeValueOrDie();
+ std::unique_ptr<ProgramShape> replayed_shape =
+ client_->GetComputationShape(replayed).ConsumeValueOrDie();
+ ASSERT_TRUE(protobuf_util::ProtobufEquals(*original_shape, *replayed_shape));
+
+ // Run it.
+ std::unique_ptr<GlobalData> x_data =
+ client_->TransferToServer(*LiteralUtil::CreateR0<int32>(2))
+ .ConsumeValueOrDie();
+ std::unique_ptr<GlobalData> y_data =
+ client_->TransferToServer(*LiteralUtil::CreateR0<int32>(3))
+ .ConsumeValueOrDie();
+ std::unique_ptr<Literal> literal =
+ client_
+ ->ExecuteAndTransfer(replayed,
+ /*arguments=*/{x_data.get(), y_data.get()})
+ .ConsumeValueOrDie();
+
+ // Expect 5.
+ LiteralTestUtil::ExpectR0Equal<int32>(5, *literal);
+}
+
+TEST_F(ReplayTest, MapPlusTwoOverR1) {
+ // As above, but with map(+2) over some constant array.
+ ComputationBuilder plus_two_builder(client_, "plus two");
+ auto input =
+ plus_two_builder.Parameter(0, ShapeUtil::MakeShape(S32, {}), "input");
+ plus_two_builder.Add(input, plus_two_builder.ConstantR0<int32>(2));
+ Computation plus_two = plus_two_builder.Build().ConsumeValueOrDie();
+
+ ComputationBuilder mapper_builder(client_, TestName());
+ auto original = mapper_builder.ConstantR1<int32>({1, 2, 3});
+ mapper_builder.Map({original}, plus_two);
+
+ Computation computation = mapper_builder.Build().ConsumeValueOrDie();
+
+ // Serialize it out.
+ std::unique_ptr<SessionModule> module =
+ computation.Snapshot().ConsumeValueOrDie();
+
+ // Replay it.
+ Computation replayed = client_->LoadSnapshot(*module).ConsumeValueOrDie();
+
+ // Check signature is the same.
+ std::unique_ptr<ProgramShape> original_shape =
+ client_->GetComputationShape(computation).ConsumeValueOrDie();
+ std::unique_ptr<ProgramShape> replayed_shape =
+ client_->GetComputationShape(replayed).ConsumeValueOrDie();
+ ASSERT_TRUE(protobuf_util::ProtobufEquals(*original_shape, *replayed_shape));
+
+ // Destroy the originals.
+ computation.Reset();
+ plus_two.Reset();
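+
+  // A successful run below shows the replayed computation is self-contained
+  // and does not depend on the original handles.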
+
+ // Run it.
+ std::unique_ptr<Literal> literal =
+ client_->ExecuteAndTransfer(replayed, /*arguments=*/{})
+ .ConsumeValueOrDie();
+
+ // Expect result.
+ LiteralTestUtil::ExpectR1Equal<int32>({3, 4, 5}, *literal);
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/reshape_motion_test.cc b/tensorflow/compiler/xla/tests/reshape_motion_test.cc
new file mode 100644
index 0000000000..ce309eb743
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/reshape_motion_test.cc
@@ -0,0 +1,77 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+#include <numeric>
+#include <random>
+#include <vector>
+
+#include "tensorflow/compiler/xla/array2d.h"
+#include "tensorflow/compiler/xla/array4d.h"
+#include "tensorflow/compiler/xla/client/computation.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/global_data.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/reference_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+using ReshapeMotionTest = ClientLibraryTestBase;
+
+TEST_F(ReshapeMotionTest, ElementwiseOfReshapesWithNonSameInputShapes) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR2<int32>({{2, 3, 5}, {7, 11, 13}});
+ auto b = builder.ConstantR2<int32>({{17, 19}, {23, 29}, {31, 37}});
+ auto c = builder.Reshape(a, {6});
+ auto d = builder.Reshape(b, {6});
+ auto e = builder.Mul(c, d);
+
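+  // Expected: the elementwise products {2*17, 3*19, 5*23, 7*29, 11*31,
+  // 13*37} of the two reshaped arrays.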
+ ComputeAndCompareR1<int32>(&builder, {34, 57, 115, 203, 341, 481}, {});
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/reshape_test.cc b/tensorflow/compiler/xla/tests/reshape_test.cc
new file mode 100644
index 0000000000..a9159d39ca
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/reshape_test.cc
@@ -0,0 +1,811 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+#include <numeric>
+#include <random>
+#include <vector>
+
+#include "tensorflow/compiler/xla/array2d.h"
+#include "tensorflow/compiler/xla/array4d.h"
+#include "tensorflow/compiler/xla/client/computation.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/global_data.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/reference_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+class ReshapeTest : public ClientLibraryTestBase {
+ public:
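+  // Reshapes move data without arithmetic, so results must match exactly.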
+ ErrorSpec zero_error_spec_{0.0};
+};
+
+// Collapses a 2-dimensional pseudo-scalar (single-element array) to 1
+// dimension.
+XLA_TEST_F(ReshapeTest, Trivial1x1) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR2<float>({{1.0}});
+ builder.Collapse(/*operand=*/a, /*dimensions=*/{0, 1});
+
+ ComputeAndCompareR1<float>(&builder, {1.0f}, {}, zero_error_spec_);
+}
+
+// Collapses a 2-dimensional pseudo-scalar (single-element array) to a scalar.
+XLA_TEST_F(ReshapeTest, SingleElementArrayToScalar) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR2<float>({{1.0}});
+ auto reshape =
+ builder.Reshape(/*operand=*/a, /*dimensions=*/{0, 1}, /*new_sizes=*/{});
+ auto new_shape = builder.GetShape(reshape).ConsumeValueOrDie();
+
+ ComputeAndCompareR0<float>(&builder, 1.0f, {}, zero_error_spec_);
+}
+
+XLA_TEST_F(ReshapeTest, Trivial0x3) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR2FromArray2D<float>(Array2D<float>(0, 3));
+ auto result = builder.Collapse(/*operand=*/a, /*dimensions=*/{0, 1});
+
+ ComputeAndCompareR1<float>(&builder, {}, {}, zero_error_spec_);
+}
+
+XLA_TEST_F(ReshapeTest, Trivial3x0) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR2FromArray2D<float>(Array2D<float>(3, 0));
+ auto result = builder.Collapse(/*operand=*/a, /*dimensions=*/{0, 1});
+
+ ComputeAndCompareR1<float>(&builder, {}, {}, zero_error_spec_);
+}
+
+// Collapses a 2-dimensional row vector to 1 dimension.
+XLA_TEST_F(ReshapeTest, Trivial1x3) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR2<float>({{1.0f, 2.0f, 3.0f}});
+ auto result = builder.Collapse(/*operand=*/a, /*dimensions=*/{0, 1});
+
+ ComputeAndCompareR1<float>(&builder, {1.0f, 2.0f, 3.0f}, {},
+ zero_error_spec_);
+}
+
+// Collapses a 2-dimensional column vector to 1 dimension.
+XLA_TEST_F(ReshapeTest, Trivial3x1) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR2<float>({{1.0f}, {2.0f}, {3.0f}});
+ auto result = builder.Collapse(/*operand=*/a, /*dimensions=*/{0, 1});
+
+ ComputeAndCompareR1<float>(&builder, {1.0f, 2.0f, 3.0f}, {},
+ zero_error_spec_);
+}
+
+// Splits an empty vector into an empty matrix.
+XLA_TEST_F(ReshapeTest, R1ToR2_0_To_2x0) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<float>({});
+ auto result =
+ builder.Reshape(/*operand=*/a, /*dimensions=*/{0}, /*new_sizes=*/{2, 0});
+ ComputeAndCompareR2<float>(&builder, Array2D<float>(2, 0), {},
+ zero_error_spec_);
+}
+
+// Splits a vector into a matrix.
+XLA_TEST_F(ReshapeTest, R1ToR2_6_To_2x3) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<float>({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f});
+ auto result =
+ builder.Reshape(/*operand=*/a, /*dimensions=*/{0}, /*new_sizes=*/{2, 3});
+ Array2D<float> expected_2x3({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}});
+ ComputeAndCompareR2<float>(&builder, expected_2x3, {}, zero_error_spec_);
+}
+
+// Reshapes a 0x2 array to a 2x0 array.
+XLA_TEST_F(ReshapeTest, Reshape0x2To2x0) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR2FromArray2D<float>(Array2D<float>(0, 2));
+ auto result = builder.Reshape(/*operand=*/a, /*dimensions=*/{0, 1},
+ /*new_sizes=*/{2, 0});
+
+ ComputeAndCompareR2<float>(&builder, Array2D<float>(2, 0), {},
+ zero_error_spec_);
+}
+
+// Transposes a 2-dimensional row vector to a column vector.
+XLA_TEST_F(ReshapeTest, ReshapeRowToCol) {
+ ComputationBuilder builder(client_, TestName());
+ auto simple = MakeLinspaceArray2D(1.0f, 3.0f, 1, 3);
+ auto a = builder.ConstantR2FromArray2D<float>(*simple);
+ auto result = builder.Reshape(/*operand=*/a, /*dimensions=*/{0, 1},
+ /*new_sizes=*/{3, 1});
+
+ auto expected = ReferenceUtil::TransposeArray2D(*simple);
+ ComputeAndCompareR2<float>(&builder, *expected, {}, zero_error_spec_);
+}
+
+// Transposes a 2-dimensional array.
+XLA_TEST_F(ReshapeTest, TransposeAsReshape) {
+ ComputationBuilder builder(client_, TestName());
+ auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3);
+ auto a = builder.ConstantR2FromArray2D<float>(*a4x3);
+ auto result = builder.Reshape(/*operand=*/a, /*dimensions=*/{1, 0},
+ /*new_sizes=*/{3, 4});
+
+ auto expected3x4 = ReferenceUtil::TransposeArray2D(*a4x3);
+ ComputeAndCompareR2<float>(&builder, *expected3x4, {}, zero_error_spec_);
+}
+
+// Transposes a 0x4 array with ComputationBuilder::Transpose.
+XLA_TEST_F(ReshapeTest, Transpose0x4) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR2FromArray2D<float>(Array2D<float>(0, 4));
+ auto result = builder.Transpose(a, {1, 0});
+
+ ComputeAndCompareR2<float>(&builder, Array2D<float>(4, 0), {},
+ zero_error_spec_);
+}
+
+// Transposes a 2-dimensional array with ComputationBuilder::Transpose.
+XLA_TEST_F(ReshapeTest, Transpose4x3) {
+ ComputationBuilder builder(client_, TestName());
+ auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3);
+ auto a = builder.ConstantR2FromArray2D<float>(*a4x3);
+ auto result = builder.Transpose(a, {1, 0});
+
+ auto expected3x4 = ReferenceUtil::TransposeArray2D(*a4x3);
+ ComputeAndCompareR2<float>(&builder, *expected3x4, {}, zero_error_spec_);
+}
+
+// Reshapes an empty 2-dimensional array to new dimensions that split the
+// originals (split) without reordering the input (no shuffle).
+XLA_TEST_F(ReshapeTest, ReshapeSplitNoShuffleZeroElements) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR2FromArray2D<float>(Array2D<float>(6, 0));
+ auto result = builder.Reshape(/*operand=*/a, /*dimensions=*/{0, 1},
+ /*new_sizes=*/{2, 3, 0, 0});
+
+ ComputeAndCompareR4<float>(&builder, Array4D<float>(2, 3, 0, 0), {},
+ zero_error_spec_);
+}
+
+XLA_TEST_F(ReshapeTest, ReshapeR4ToR2ZeroElements) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR4FromArray4D<float>(Array4D<float>(2, 3, 4, 0));
+ auto result = builder.Reshape(/*operand=*/a, /*dimensions=*/{0, 1, 2, 3},
+ /*new_sizes=*/{24, 0});
+
+ ComputeAndCompareR2<float>(&builder, Array2D<float>(24, 0), {},
+ zero_error_spec_);
+}
+
+// Reshapes a 2-dimensional array to new dimensions that split the originals
+// (split) without reordering the input (no shuffle).
+XLA_TEST_F(ReshapeTest, ReshapeSplitNoShuffle) {
+ ComputationBuilder builder(client_, TestName());
+ auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3);
+ auto a = builder.ConstantR2FromArray2D<float>(*a4x3);
+ auto result = builder.Reshape(/*operand=*/a, /*dimensions=*/{0, 1},
+ /*new_sizes=*/{2, 6});
+
+ auto expected2x6 = MakeLinspaceArray2D(1.0f, 12.0f, 2, 6);
+ ComputeAndCompareR2<float>(&builder, *expected2x6, {}, zero_error_spec_);
+}
+
+// Reshapes an empty 2-dimensional array to new dimensions that split the
+// originals (split) and reorder the input (shuffle).
+XLA_TEST_F(ReshapeTest, ReshapeSplitAndShuffleZeroElements) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR2FromArray2D<float>(Array2D<float>(0, 6));
+ auto result = builder.Reshape(/*operand=*/a, /*dimensions=*/{1, 0},
+ /*new_sizes=*/{3, 0});
+
+ ComputeAndCompareR2<float>(&builder, Array2D<float>(3, 0), {},
+ zero_error_spec_);
+}
+
+// Reshapes a 2-dimensional array to new dimensions that split the originals
+// (split) and reorder the input (shuffle).
+XLA_TEST_F(ReshapeTest, ReshapeSplitAndShuffle) {
+ ComputationBuilder builder(client_, TestName());
+ auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3);
+ auto a = builder.ConstantR2FromArray2D<float>(*a4x3);
+ auto result = builder.Reshape(/*operand=*/a, /*dimensions=*/{1, 0},
+ /*new_sizes=*/{2, 6});
+
+ Array2D<float> expected2x6({{1.0f, 4.0f, 7.0f, 10.0f, 2.0f, 5.0f},
+ {8.0f, 11.0f, 3.0f, 6.0f, 9.0f, 12.0f}});
+ ComputeAndCompareR2<float>(&builder, expected2x6, {}, zero_error_spec_);
+}
+
+// The following tests use the same input 3D array; they test the examples we
+// show for the Reshape operation in the operation_semantics document.
+// TODO(eliben): find a way to show this code in the documentation without
+// duplication.
+Array3D<int> v_array_for_doc_R3_tests({{{10, 11, 12}, {15, 16, 17}},
+ {{20, 21, 22}, {25, 26, 27}},
+ {{30, 31, 32}, {35, 36, 37}},
+ {{40, 41, 42}, {45, 46, 47}}});
+
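+// In the Reshape calls below, the "dimensions" argument first transposes the
+// input so that its dimensions appear in the listed order; the elements are
+// then read out in row-major order and refilled into "new_sizes".
+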
+XLA_TEST_F(ReshapeTest, DocR3_R1_Collapse_012) {
+ ComputationBuilder builder(client_, TestName());
+ auto v = builder.ConstantR3FromArray3D<int>(v_array_for_doc_R3_tests);
+ auto result = builder.Reshape(/*operand=*/v, /*dimensions=*/{0, 1, 2},
+ /*new_sizes=*/{24});
+ ComputeAndCompareR1<int>(&builder,
+ {10, 11, 12, 15, 16, 17, 20, 21, 22, 25, 26, 27,
+ 30, 31, 32, 35, 36, 37, 40, 41, 42, 45, 46, 47},
+ {});
+}
+
+XLA_TEST_F(ReshapeTest, DocR3_R2_Collapse_012_Refine_83) {
+ ComputationBuilder builder(client_, TestName());
+ auto v = builder.ConstantR3FromArray3D<int>(v_array_for_doc_R3_tests);
+ auto result = builder.Reshape(/*operand=*/v, /*dimensions=*/{0, 1, 2},
+ /*new_sizes=*/{8, 3});
+ Array2D<int> expected({{10, 11, 12},
+ {15, 16, 17},
+ {20, 21, 22},
+ {25, 26, 27},
+ {30, 31, 32},
+ {35, 36, 37},
+ {40, 41, 42},
+ {45, 46, 47}});
+ ComputeAndCompareR2<int>(&builder, expected, {});
+}
+
+XLA_TEST_F(ReshapeTest, DocR3_R1_Collapse_120) {
+ ComputationBuilder builder(client_, TestName());
+ auto v = builder.ConstantR3FromArray3D<int>(v_array_for_doc_R3_tests);
+ auto result = builder.Reshape(/*operand=*/v, /*dimensions=*/{1, 2, 0},
+ /*new_sizes=*/{24});
+ ComputeAndCompareR1<int>(&builder,
+ {10, 20, 30, 40, 11, 21, 31, 41, 12, 22, 32, 42,
+ 15, 25, 35, 45, 16, 26, 36, 46, 17, 27, 37, 47},
+ {});
+}
+
+XLA_TEST_F(ReshapeTest, DocR3_R2_Collapse_120_Refine_83) {
+ ComputationBuilder builder(client_, TestName());
+ auto v = builder.ConstantR3FromArray3D<int>(v_array_for_doc_R3_tests);
+ auto result = builder.Reshape(/*operand=*/v, /*dimensions=*/{1, 2, 0},
+ /*new_sizes=*/{8, 3});
+ Array2D<int> expected({{10, 20, 30},
+ {40, 11, 21},
+ {31, 41, 12},
+ {22, 32, 42},
+ {15, 25, 35},
+ {45, 16, 26},
+ {36, 46, 17},
+ {27, 37, 47}});
+ ComputeAndCompareR2<int>(&builder, expected, {});
+}
+
+XLA_TEST_F(ReshapeTest, DocR3_R3_Collapse_120_Refine_262) {
+ ComputationBuilder builder(client_, TestName());
+ auto v = builder.ConstantR3FromArray3D<int>(v_array_for_doc_R3_tests);
+ auto result = builder.Reshape(/*operand=*/v, /*dimensions=*/{1, 2, 0},
+ /*new_sizes=*/{2, 6, 2});
+ Array3D<int> expected(
+ {{{10, 20}, {30, 40}, {11, 21}, {31, 41}, {12, 22}, {32, 42}},
+ {{15, 25}, {35, 45}, {16, 26}, {36, 46}, {17, 27}, {37, 47}}});
+ ComputeAndCompareR3<int>(&builder, expected, {});
+}
+
+// Collapses the low dimensions of a 4D tensor to get a 2D matrix, without
+// reordering dimensions (for NeuralNet::FullyConnected).
+//
+// First we create a tesseract raster-face like:
+//
+// 1 2 3
+// 4 5 6
+//
+// Then we collapse Y and X within the raster space, yielding:
+//
+// 1 2 3 4 5 6
+//
+// Then we collapse Z, so we just end up with planes:
+//
+// 1 2 3 4 5 6 1 2 3 4 5 6
+XLA_TEST_F(ReshapeTest, FullyConnectedCollapse) {
+ ComputationBuilder builder(client_, TestName());
+ Array4D<float> t2x2x2x3(2, 2, 2, 3);
+ auto filler2x3 = MakeLinspaceArray2D(1.0f, 6.0f, 2, 3);
+ t2x2x2x3.FillWithYX(*filler2x3);
+ auto a = builder.ConstantR4FromArray4D<float>(t2x2x2x3);
+ auto result = builder.Collapse(/*operand=*/a, /*dimensions=*/{1, 2, 3});
+
+ Array2D<float> expected2x12(
+ {{1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f},
+ {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
+ 6.0f}});
+ ComputeAndCompareR2<float>(&builder, expected2x12, {}, zero_error_spec_);
+}
+
+// As above, but uses reshape directly.
+XLA_TEST_F(ReshapeTest, FullyConnectedCollapseDesugared) {
+ ComputationBuilder builder(client_, TestName());
+ Array4D<float> t(2, 1, 2, 2);
+ t(0, 0, 0, 0) = 0;
+ t(0, 0, 0, 1) = 1;
+ t(0, 0, 1, 0) = 2;
+ t(0, 0, 1, 1) = 3;
+ t(1, 0, 0, 0) = 4;
+ t(1, 0, 0, 1) = 5;
+ t(1, 0, 1, 0) = 6;
+ t(1, 0, 1, 1) = 7;
+ auto a = builder.ConstantR4FromArray4D<float>(t);
+ auto result = builder.Reshape(/*operand=*/a, /*dimensions=*/{0, 1, 2, 3},
+ /*new_sizes=*/{2, 4});
+
+ Array2D<float> expected({{0, 1, 2, 3}, {4, 5, 6, 7}});
+ ComputeAndCompareR2<float>(&builder, expected, {}, zero_error_spec_);
+}
+
+// Reshapes single-element arrays of various ranks to a scalar.
+XLA_TEST_F(ReshapeTest, ToScalar) {
+ for (int rank = 0; rank < 8; ++rank) {
+ ComputationBuilder b(client_, TestName());
+ auto input = LiteralUtil::CreateR1<float>({83.0f});
+ std::vector<int64> ones(rank, 1); // this is {1, ..., 1}.
+ std::vector<int64> dimensions(rank);
+ std::iota(dimensions.begin(), dimensions.end(), 0);
+ *input->mutable_shape() = ShapeUtil::MakeShape(F32, ones);
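+    // The single-element literal now has shape f32[1, ..., 1] of this rank.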
+ b.Reshape(b.ConstantLiteral(*input), dimensions, {});
+
+ ComputeAndCompareR0<float>(&b, 83.0f, {}, zero_error_spec_);
+ }
+}
+
+XLA_TEST_F(ReshapeTest, BadDimensions) {
+ ComputationBuilder b(client_, TestName());
+ b.Reshape(b.ConstantR1<int32>({1}), {}, {});
+ EXPECT_MATCH(ExecuteToString(&b, {}),
+ testing::HasSubstr("dimensions not a permutation"));
+}
+
+XLA_TEST_F(ReshapeTest, BadNewSizes) {
+ ComputationBuilder b(client_, TestName());
+ b.Reshape(b.ConstantR1<int32>({1, 2}), {1}, {});
+ EXPECT_MATCH(ExecuteToString(&b, {}),
+ testing::HasSubstr("mismatched element counts"));
+}
+
+XLA_TEST_F(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) {
+ const Shape parameter_shape = ShapeUtil::MakeShape(F32, {2, 2, 2, 2});
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.Parameter(0, parameter_shape, "a");
+ builder.Reshape(a, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{2, 8});
+
+ // clang-format off
+ auto literal = LiteralUtil::CreateR4FromArray4DWithLayout(Array4D<float>{
+ {
+ {
+ {0, 1},
+ {2, 3},
+ },
+ {
+ {100, 101},
+ {102, 103},
+ },
+ },
+ {
+ {
+ {222, 333},
+ {444, 555},
+ },
+ {
+ {666, 777},
+ {888, 999},
+ },
+ },
+ },
+ LayoutUtil::MakeLayout({0, 1, 2, 3}));
+ // clang-format on
+ std::unique_ptr<GlobalData> input =
+ client_->TransferToServer(*literal).ConsumeValueOrDie();
+ Array2D<float> expected_array({
+ {0, 1, 2, 3, 100, 101, 102, 103},
+ {222, 333, 444, 555, 666, 777, 888, 999},
+ });
+
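+  // The reshape is defined on logical values, so the result does not depend
+  // on the operand's dim0-minor layout; the explicit output layout below only
+  // fixes how the returned literal is laid out.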
+ Computation computation = builder.Build().ConsumeValueOrDie();
+ const Shape shape_with_output_layout =
+ ShapeUtil::MakeShapeWithLayout(F32, {2, 8}, {1, 0});
+ std::unique_ptr<Literal> actual =
+ client_
+ ->ExecuteAndTransfer(computation, {input.get()},
+ &shape_with_output_layout)
+ .ConsumeValueOrDie();
+ std::unique_ptr<Literal> expected =
+ LiteralUtil::CreateR2FromArray2D<float>(expected_array);
+ LiteralTestUtil::ExpectEqual(*expected, *actual);
+}
+
+XLA_TEST_F(ReshapeTest, R2ToR4_3x8_To_3x2x1x4) {
+ std::unique_ptr<Literal> input = LiteralUtil::CreateR2<float>({
+ {0, 1, 2, 3, 4, 5, 6, 7},
+ {100, 101, 102, 103, 104, 105, 106, 107},
+ {200, 201, 202, 203, 204, 205, 206, 207},
+ });
+ std::unique_ptr<GlobalData> input_data =
+ client_->TransferToServer(*input).ConsumeValueOrDie();
+
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.Parameter(0, input->shape(), "a");
+ builder.Reshape(a, /*dimensions=*/{0, 1}, /*new_sizes=*/{3, 2, 1, 4});
+
+ // clang-format off
+ Array4D<float> expected = {
+ {{{0, 1, 2, 3}},
+ {{4, 5, 6, 7}}},
+ {{{100, 101, 102, 103}},
+ {{104, 105, 106, 107}}},
+ {{{200, 201, 202, 203}},
+ {{204, 205, 206, 207}}}
+ };
+ // clang-format on
+ ComputeAndCompareR4<float>(&builder, expected, {input_data.get()},
+ zero_error_spec_);
+}
+
+// Tests R2->R4 reshape with the reshape dimensions {1, 0}.
+XLA_TEST_F(ReshapeTest, R2ToR4_3x8_To_3x2x1x4_Dimensions_10) {
+ std::unique_ptr<Literal> input = LiteralUtil::CreateR2<float>({
+ {0, 1, 2, 3, 4, 5, 6, 7},
+ {100, 101, 102, 103, 104, 105, 106, 107},
+ {200, 201, 202, 203, 204, 205, 206, 207},
+ });
+ std::unique_ptr<GlobalData> input_data =
+ client_->TransferToServer(*input).ConsumeValueOrDie();
+
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.Parameter(0, input->shape(), "a");
+ builder.Reshape(a, /*dimensions=*/{1, 0}, /*new_sizes=*/{3, 2, 1, 4});
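+  // dimensions={1, 0} first transposes the 3x8 input to 8x3, so the values
+  // are consumed in the order 0, 100, 200, 1, 101, 201, ... when filling the
+  // 3x2x1x4 result.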
+
+ // clang-format off
+ Array4D<float> expected = {
+ {{{0, 100, 200, 1}},
+ {{101, 201, 2, 102}}},
+ {{{202, 3, 103, 203}},
+ {{4, 104, 204, 5}}},
+ {{{105, 205, 6, 106}},
+ {{206, 7, 107, 207}}}
+ };
+ // clang-format on
+ ComputeAndCompareR4<float>(&builder, expected, {input_data.get()},
+ zero_error_spec_);
+}
+
+XLA_TEST_F(ReshapeTest, R4ToR2_2x1x1x1_To_2x1) {
+ std::mt19937 rng;
+ std::uniform_real_distribution<float> distribution;
+ Array4D<float> input(2, 1, 1, 1);
+ input.Each(
+ [&rng, &distribution](tensorflow::gtl::ArraySlice<int64> /* indices */,
+ float* cell) { *cell = distribution(rng); });
+ std::unique_ptr<Literal> input_literal =
+ LiteralUtil::CreateR4FromArray4DWithLayout(
+ input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
+ std::unique_ptr<GlobalData> input_data =
+ client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.Parameter(0, input_literal->shape(), "a");
+ builder.Reshape(a, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{2, 1});
+
+ std::unique_ptr<Literal> expected =
+ LiteralTestUtil::Reshape({2, 1}, {1, 0}, *input_literal);
+ ComputeAndCompareLiteral(&builder, *expected, {input_data.get()},
+ zero_error_spec_);
+}
+
+XLA_TEST_F(ReshapeTest, R4ToR2_2x1x4x1_To_4x2) {
+ std::mt19937 rng;
+ std::uniform_real_distribution<float> distribution;
+ Array4D<float> input(2, 1, 4, 1);
+ input.Each(
+ [&rng, &distribution](tensorflow::gtl::ArraySlice<int64> /* indices */,
+ float* cell) { *cell = distribution(rng); });
+ std::unique_ptr<Literal> input_literal =
+ LiteralUtil::CreateR4FromArray4DWithLayout(
+ input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
+ std::unique_ptr<GlobalData> input_data =
+ client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.Parameter(0, input_literal->shape(), "a");
+ builder.Reshape(a, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{4, 2});
+
+ std::unique_ptr<Literal> expected =
+ LiteralTestUtil::Reshape({4, 2}, {1, 0}, *input_literal);
+ ComputeAndCompareLiteral(&builder, *expected, {input_data.get()},
+ zero_error_spec_);
+}
+
+// Tests R4->R2 reshape with the reshape dimensions {0, 2, 1, 3}.
+XLA_TEST_F(ReshapeTest, R4ToR2_5x10x2x3_To_5x60_Dimensions_0213) {
+ std::mt19937 rng;
+ std::uniform_real_distribution<float> distribution;
+ Array4D<float> input(5, 10, 2, 3);
+ input.Each(
+ [&rng, &distribution](tensorflow::gtl::ArraySlice<int64> /* indices */,
+ float* cell) { *cell = distribution(rng); });
+ std::unique_ptr<Literal> input_literal =
+ LiteralUtil::CreateR4FromArray4DWithLayout(
+ input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
+ std::unique_ptr<GlobalData> input_data =
+ client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.Parameter(0, input_literal->shape(), "a");
+ builder.Reshape(a, /*dimensions=*/{0, 2, 1, 3}, /*new_sizes=*/{5, 60});
+
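+  // For input index (i0, i1, i2, i3), dimensions={0, 2, 1, 3} places the
+  // value at output column i2 * 30 + i1 * 3 + i3 of row i0.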
+ Array2D<float> expected_array(5, 60);
+ input.Each([&](tensorflow::gtl::ArraySlice<int64> indices, float* cell) {
+ expected_array(indices[0], indices[2] * 30 + indices[1] * 3 + indices[3]) =
+ *cell;
+ });
+ auto expected = LiteralUtil::CreateR2FromArray2D(expected_array);
+ ComputeAndCompareLiteral(&builder, *expected, {input_data.get()});
+}
+
+XLA_TEST_F(ReshapeTest, NoopReshape) {
+ std::mt19937 rng;
+ std::uniform_real_distribution<float> distribution;
+ Array4D<float> input_array(2, 3, 5, 7);
+ input_array.Each(
+ [&rng, &distribution](tensorflow::gtl::ArraySlice<int64> /* indices */,
+ float* cell) { *cell = distribution(rng); });
+ std::unique_ptr<Literal> input_literal =
+ LiteralUtil::CreateR4FromArray4DWithLayout(
+ input_array, LayoutUtil::MakeLayout({1, 2, 3, 0}));
+ std::unique_ptr<GlobalData> input_data =
+ client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+
+ ComputationBuilder builder(client_, TestName());
+ auto input = builder.Parameter(0, input_literal->shape(), "input");
+ builder.Reshape(input, /*dimensions=*/{3, 0, 1, 2},
+ /*new_sizes=*/{7, 2, 3, 5});
+ Computation computation = builder.Build().ConsumeValueOrDie();
+
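+  // Composing the dimension permutation {3, 0, 1, 2} with the output layout
+  // {2, 3, 0, 1} yields the same physical element order as the input layout
+  // {1, 2, 3, 0}, so the reshape should not move any data.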
+ const Shape output_shape_with_layout =
+ ShapeUtil::MakeShapeWithLayout(F32, {7, 2, 3, 5}, {2, 3, 0, 1});
+ std::unique_ptr<Literal> output_literal =
+ client_
+ ->ExecuteAndTransfer(computation, {input_data.get()},
+ &output_shape_with_layout)
+ .ConsumeValueOrDie();
+
+ // Since the reshape is a no-op, verify that it does not change the underlying
+ // data.
+ EXPECT_EQ(tensorflow::gtl::ArraySlice<float>(input_literal->f32s()),
+ tensorflow::gtl::ArraySlice<float>(output_literal->f32s()));
+}
+
+XLA_TEST_F(ReshapeTest, R4ToR4Reshape_Trivial) {
+ auto literal_1x2x3x4 = LiteralUtil::CreateR4(
+ {{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}},
+ {{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}});
+
+ ComputationBuilder builder(client_, TestName());
+ auto input = builder.ConstantLiteral(*literal_1x2x3x4);
+ builder.Reshape(input, /*dimensions=*/{0, 1, 2, 3},
+ /*new_sizes=*/{1, 2, 3, 4});
+
+ ComputeAndCompareLiteral(&builder, *literal_1x2x3x4, {});
+}
+
+XLA_TEST_F(ReshapeTest, R4ToR4Reshape) {
+ auto literal_1x2x3x4 = LiteralUtil::CreateR4(
+ {{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}},
+ {{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}});
+
+ ComputationBuilder builder(client_, TestName());
+ auto input = builder.ConstantLiteral(*literal_1x2x3x4);
+ builder.Reshape(input, /*dimensions=*/{1, 3, 2, 0},
+ /*new_sizes=*/{2, 4, 3, 1});
+
+ // clang-format off
+ auto expected_2x4x3x1 = LiteralUtil::CreateR4(
+ {{{{1}, {5}, {9}},
+ {{2}, {6}, {10}},
+ {{3}, {7}, {11}},
+ {{4}, {8}, {12}}},
+ {{{13}, {17}, {21}},
+ {{14}, {18}, {22}},
+ {{15}, {19}, {23}},
+ {{16}, {20}, {24}}}});
+ // clang-format on
+
+ ComputeAndCompareLiteral(&builder, *expected_2x4x3x1, {});
+}
+
+XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeSimple) {
+ std::mt19937 rng;
+ std::uniform_real_distribution<float> distribution;
+ std::vector<int64> bounds = {2, 2, 2, 2};
+ std::vector<int64> new_bounds = {bounds[0], bounds[1], bounds[3], bounds[2]};
+ Array4D<float> input(bounds[0], bounds[1], bounds[2], bounds[3]);
+ input.Each(
+ [&rng, &distribution](tensorflow::gtl::ArraySlice<int64> /* indices */,
+ float* cell) { *cell = distribution(rng); });
+ std::unique_ptr<Literal> input_literal =
+ LiteralUtil::CreateR4FromArray4DWithLayout(
+ input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
+ std::unique_ptr<GlobalData> input_data =
+ client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.Parameter(0, input_literal->shape(), "a");
+ builder.Reshape(a, /*dimensions=*/{0, 1, 3, 2}, /*new_sizes=*/new_bounds);
+
+ std::unique_ptr<Literal> expected = LiteralUtil::Relayout(
+ *LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal),
+ LayoutUtil::MakeLayout({3, 2, 1, 0}));
+
+ // Specify the requested output shape explicitly to ensure that this reshape
+ // actually corresponds to a two minor transpose.
+ ComputeAndCompareLiteral(&builder, *expected, {input_data.get()},
+ zero_error_spec_, &expected->shape());
+}
+
+XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeMajorFirstEffectiveR2) {
+ std::mt19937 rng;
+ std::uniform_real_distribution<float> distribution;
+ std::vector<int64> bounds = {1, 1, 250, 300};
+ std::vector<int64> new_bounds = {bounds[0], bounds[1], bounds[3], bounds[2]};
+ Array4D<float> input(bounds[0], bounds[1], bounds[2], bounds[3]);
+ input.Each(
+ [&rng, &distribution](tensorflow::gtl::ArraySlice<int64> /* indices */,
+ float* cell) { *cell = distribution(rng); });
+ std::unique_ptr<Literal> input_literal =
+ LiteralUtil::CreateR4FromArray4DWithLayout(
+ input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
+ std::unique_ptr<GlobalData> input_data =
+ client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.Parameter(0, input_literal->shape(), "a");
+ builder.Reshape(a, /*dimensions=*/{0, 1, 3, 2}, /*new_sizes=*/new_bounds);
+
+ std::unique_ptr<Literal> expected = LiteralUtil::Relayout(
+ *LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal),
+ LayoutUtil::MakeLayout({3, 2, 1, 0}));
+
+ // Specify the requested output shape explicitly to ensure that this reshape
+ // actually corresponds to a two minor transpose.
+ ComputeAndCompareLiteral(&builder, *expected, {input_data.get()},
+ zero_error_spec_, &expected->shape());
+}
+
+XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1) {
+ std::mt19937 rng;
+ std::uniform_real_distribution<float> distribution;
+ std::vector<int64> bounds = {5, 5, 1, 10};
+ std::vector<int64> new_bounds = {bounds[0], bounds[1], bounds[3], bounds[2]};
+ Array4D<float> input(bounds[0], bounds[1], bounds[2], bounds[3]);
+ input.Each(
+ [&rng, &distribution](tensorflow::gtl::ArraySlice<int64> /* indices */,
+ float* cell) { *cell = distribution(rng); });
+ std::unique_ptr<Literal> input_literal =
+ LiteralUtil::CreateR4FromArray4DWithLayout(
+ input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
+ std::unique_ptr<GlobalData> input_data =
+ client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.Parameter(0, input_literal->shape(), "a");
+ builder.Reshape(a, /*dimensions=*/{0, 1, 3, 2}, /*new_sizes=*/new_bounds);
+
+ std::unique_ptr<Literal> expected = LiteralUtil::Relayout(
+ *LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal),
+ LayoutUtil::MakeLayout({3, 2, 1, 0}));
+
+ // Specify the requested output shape explicitly to ensure that this reshape
+ // actually corresponds to a two minor transpose.
+ ComputeAndCompareLiteral(&builder, *expected, {input_data.get()},
+ zero_error_spec_, &expected->shape());
+}
+
+XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1InR2) {
+ std::mt19937 rng;
+ std::uniform_real_distribution<float> distribution;
+ // This happens in NN-Builder MNIST.
+ std::vector<int64> bounds = {5, 5, 10, 1};
+ std::vector<int64> new_bounds = {bounds[0], bounds[1], bounds[3], bounds[2]};
+ Array4D<float> input(bounds[0], bounds[1], bounds[2], bounds[3]);
+ input.Each(
+ [&rng, &distribution](tensorflow::gtl::ArraySlice<int64> /* indices */,
+ float* cell) { *cell = distribution(rng); });
+ std::unique_ptr<Literal> input_literal =
+ LiteralUtil::CreateR4FromArray4DWithLayout(
+ input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
+ std::unique_ptr<GlobalData> input_data =
+ client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.Parameter(0, input_literal->shape(), "a");
+ builder.Reshape(a, /*dimensions=*/{0, 1, 3, 2}, /*new_sizes=*/new_bounds);
+
+ std::unique_ptr<Literal> expected = LiteralUtil::Relayout(
+ *LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal),
+ LayoutUtil::MakeLayout({3, 2, 1, 0}));
+
+ // Specify the requested output shape explicitly to ensure that this reshape
+ // actually corresponds to a two minor transpose.
+ ComputeAndCompareLiteral(&builder, *expected, {input_data.get()},
+ zero_error_spec_, &expected->shape());
+}
+
+XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeTrivialR2) {
+ std::mt19937 rng;
+ std::uniform_real_distribution<float> distribution;
+ std::vector<int64> bounds = {3, 3, 1, 3};
+ std::vector<int64> new_bounds = {bounds[1], bounds[0], bounds[2], bounds[3]};
+ Array4D<float> input(bounds[0], bounds[1], bounds[2], bounds[3]);
+ input.Each(
+ [&rng, &distribution](tensorflow::gtl::ArraySlice<int64> /* indices */,
+ float* cell) { *cell = distribution(rng); });
+ std::unique_ptr<Literal> input_literal =
+ LiteralUtil::CreateR4FromArray4DWithLayout(
+ input, LayoutUtil::MakeLayout({0, 1, 2, 3}));
+ std::unique_ptr<GlobalData> input_data =
+ client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.Parameter(0, input_literal->shape(), "a");
+ builder.Reshape(a, /*dimensions=*/{1, 0, 2, 3}, /*new_sizes=*/new_bounds);
+
+ std::unique_ptr<Literal> expected = LiteralUtil::Relayout(
+ *LiteralTestUtil::Reshape(new_bounds, {1, 0, 2, 3}, *input_literal),
+ input_literal->shape().layout());
+
+ // Specify the requested output shape explicitly to ensure that this reshape
+ // actually corresponds to a two minor transpose.
+ ComputeAndCompareLiteral(&builder, *expected, {input_data.get()},
+ zero_error_spec_, &expected->shape());
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/reverse_test.cc b/tensorflow/compiler/xla/tests/reverse_test.cc
new file mode 100644
index 0000000000..63dd4421fa
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/reverse_test.cc
@@ -0,0 +1,173 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+
+#include "tensorflow/compiler/xla/array2d.h"
+#include "tensorflow/compiler/xla/array4d.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+class ReverseTest : public ClientLibraryTestBase {};
+
+// Tests the reverse operation on a scalar.
+XLA_TEST_F(ReverseTest, ReverseScalar) {
+ ComputationBuilder b(client_, TestName());
+ float input = 3.5f;
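+  // Reversing along an empty set of dimensions is the identity.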
+ b.Rev(b.ConstantR0<float>(input), {});
+ ComputeAndCompareR0<float>(&b, input, {});
+}
+
+// Tests the reverse operation on a 0x0 float array on both dimensions.
+XLA_TEST_F(ReverseTest, Reverse0x0FloatArray) {
+ ComputationBuilder b(client_, TestName());
+ b.Rev(b.ConstantR2FromArray2D<float>(Array2D<float>(0, 0)), {0, 1});
+ ComputeAndCompareR2<float>(&b, Array2D<float>(0, 0), {});
+}
+
+// Tests the reverse operation on a 0x1 float array on both dimensions.
+XLA_TEST_F(ReverseTest, Reverse0x1FloatArray) {
+ ComputationBuilder b(client_, TestName());
+ b.Rev(b.ConstantR2FromArray2D<float>(Array2D<float>(0, 1)), {0, 1});
+ ComputeAndCompareR2<float>(&b, Array2D<float>(0, 1), {});
+}
+
+// Tests the reverse operation on a 1x0 float array on both dimensions.
+XLA_TEST_F(ReverseTest, Reverse1x0FloatArray) {
+ ComputationBuilder b(client_, TestName());
+ b.Rev(b.ConstantR2FromArray2D<float>(Array2D<float>(1, 0)), {0, 1});
+ ComputeAndCompareR2<float>(&b, Array2D<float>(1, 0), {});
+}
+
+// Tests the reverse operation on a 1x1 float array on both dimensions.
+XLA_TEST_F(ReverseTest, Reverse1x1FloatArray) {
+ ComputationBuilder b(client_, TestName());
+ Array2D<float> input({{3.5f}});
+ b.Rev(b.ConstantR2FromArray2D<float>(input), {0, 1});
+ ComputeAndCompareR2<float>(&b, input, {});
+}
+
+XLA_TEST_F(ReverseTest, Reverse2x0x4x3FloatArrayDim02) {
+ ComputationBuilder b(client_, TestName());
+ b.Rev(b.ConstantR4FromArray4D<float>(Array4D<float>(2, 0, 4, 3)), {0, 2});
+ ComputeAndCompareR4<float>(&b, Array4D<float>(2, 0, 4, 3), {});
+}
+
+XLA_TEST_F(ReverseTest, Reverse2x0x4x3FloatArrayDim13) {
+ ComputationBuilder b(client_, TestName());
+ b.Rev(b.ConstantR4FromArray4D<float>(Array4D<float>(2, 0, 4, 3)), {1, 3});
+ ComputeAndCompareR4<float>(&b, Array4D<float>(2, 0, 4, 3), {});
+}
+
+// Tests the reverse operation on a 4D U8 array on dimensions 0 and 3.
+XLA_TEST_F(ReverseTest, Reverse4DU8ArrayOnDim23) {
+ ComputationBuilder b(client_, TestName());
+ // Input shape is U8[1x2x3x4].
+ // clang-format off
+ Array4D<uint8> input({{
+ {{1, 2, 3, 4},
+ {5, 6, 7, 8},
+ {9, 10, 11, 12}},
+ {{13, 14, 15, 16},
+ {17, 18, 19, 20},
+ {21, 22, 23, 24}},
+ }});
+ // clang-format on
+
+ b.Rev(b.ConstantR4FromArray4D<uint8>(input), {0, 3});
+
+ // clang-format off
+ Array4D<uint8> expected({{
+ {{4, 3, 2, 1},
+ {8, 7, 6, 5},
+ {12, 11, 10, 9}},
+ {{16, 15, 14, 13},
+ {20, 19, 18, 17},
+ {24, 23, 22, 21}},
+ }});
+ // clang-format on
+ ComputeAndCompareR4<uint8>(&b, expected, {});
+}
+
+// Tests the reverse operation on a 4D float array on dimensions 0 and 1.
+TEST_F(ReverseTest, Reverse4DFloatArrayOnDim01) {
+ ComputationBuilder b(client_, TestName());
+ // Input shape is float[4x3x2x1].
+ // clang-format off
+ Array4D<float> input({
+ {{{1.0f}, {2.0f}},
+ {{3.0f}, {4.0f}},
+ {{5.0f}, {6.0f}}},
+ {{{7.0f}, {8.0f}},
+ {{9.0f}, {10.0f}},
+ {{11.0f}, {12.0f}}},
+ {{{13.0f}, {14.0f}},
+ {{15.0f}, {16.0f}},
+ {{17.0f}, {18.0f}}},
+ {{{19.0f}, {20.0f}},
+ {{21.0f}, {22.0f}},
+ {{23.0f}, {24.0f}}},
+ });
+ // clang-format on
+
+ b.Rev(b.ConstantR4FromArray4D<float>(input), {0, 1});
+
+ // clang-format off
+ Array4D<float> expected({
+ {{{23.0f}, {24.0f}},
+ {{21.0f}, {22.0f}},
+ {{19.0f}, {20.0f}}},
+ {{{17.0f}, {18.0f}},
+ {{15.0f}, {16.0f}},
+ {{13.0f}, {14.0f}}},
+ {{{11.0f}, {12.0f}},
+ {{9.0f}, {10.0f}},
+ {{7.0f}, {8.0f}}},
+ {{{5.0f}, {6.0f}},
+ {{3.0f}, {4.0f}},
+ {{1.0f}, {2.0f}}},
+ });
+ // clang-format on
+ ComputeAndCompareR4<float>(&b, expected, {}, ErrorSpec(0.0001));
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc b/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc
new file mode 100644
index 0000000000..5b734c0f40
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc
@@ -0,0 +1,160 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+
+#include "tensorflow/compiler/xla/client/global_data.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/packed_literal_reader.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/casts.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+class RoundTripPackedLiteralTest : public ClientLibraryTestBase {
+ protected:
+ // Sends the literal to the server and retrieves it back.
+ std::unique_ptr<Literal> RoundTripToServer(const Literal& original) {
+ std::unique_ptr<GlobalData> data =
+ client_->TransferToServer(original).ConsumeValueOrDie();
+ return client_->Transfer(*data).ConsumeValueOrDie();
+ }
+};
+
+TEST_F(RoundTripPackedLiteralTest, RoundTripsR1F32Length2) {
+ string data(sizeof(float) * 2, 0);
+ tensorflow::gtl::MutableArraySlice<float> floats(
+ tensorflow::bit_cast<float*>(data.data()), 2);
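+  // Reinterpret the string's bytes as a float buffer so the file holds the
+  // raw packed representation.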
+ floats[0] = 42.0;
+ floats[1] = 24.0;
+
+ string fname = tensorflow::testing::TmpDir() + "/RoundTripsR1F32Length2.data";
+ EXPECT_TRUE(
+ tensorflow::WriteStringToFile(tensorflow::Env::Default(), fname, data)
+ .ok());
+
+ std::unique_ptr<tensorflow::RandomAccessFile> f;
+ TF_CHECK_OK(tensorflow::Env::Default()->NewRandomAccessFile(fname, &f));
+ PackedLiteralReader reader(f.release());
+ std::unique_ptr<Literal> actual =
+ reader.Read(ShapeUtil::MakeShape(F32, {2})).ConsumeValueOrDie();
+ EXPECT_TRUE(reader.IsExhausted());
+
+ EXPECT_EQ(42.0, LiteralUtil::Get<float>(*actual, {0}));
+ EXPECT_EQ(24.0, LiteralUtil::Get<float>(*actual, {1}));
+}
+
+TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim0Minor) {
+ string data(sizeof(float) * 4, 0);
+ tensorflow::gtl::MutableArraySlice<float> floats(
+ tensorflow::bit_cast<float*>(data.data()), 4);
+ // With x as the minor dimension, these will become:
+ floats[0] = 42.0; // y=0,x=0
+ floats[1] = 24.0; // y=0,x=1
+ floats[2] = 64.0; // y=1,x=0
+ floats[3] = 46.0; // y=1,x=1
+
+ string fname =
+ tensorflow::testing::TmpDir() + "/RoundTripsR2F32Size2x2Dim0Minor.data";
+ EXPECT_TRUE(
+ tensorflow::WriteStringToFile(tensorflow::Env::Default(), fname, data)
+ .ok());
+
+ const Layout layout = LayoutUtil::MakeLayout({1, 0});
+
+ std::unique_ptr<tensorflow::RandomAccessFile> f;
+ TF_CHECK_OK(tensorflow::Env::Default()->NewRandomAccessFile(fname, &f));
+ PackedLiteralReader reader(f.release());
+ std::unique_ptr<Literal> actual =
+ reader.Read(ShapeUtil::MakeShape(F32, {2, 2}), &layout)
+ .ConsumeValueOrDie();
+ EXPECT_TRUE(reader.IsExhausted());
+
+ EXPECT_EQ(42.0f, LiteralUtil::Get<float>(*actual, {0, 0}));
+ EXPECT_EQ(24.0f, LiteralUtil::Get<float>(*actual, {0, 1}));
+ EXPECT_EQ(64.0f, LiteralUtil::Get<float>(*actual, {1, 0}));
+ EXPECT_EQ(46.0f, LiteralUtil::Get<float>(*actual, {1, 1}));
+
+ std::unique_ptr<Literal> round_tripped = RoundTripToServer(*actual);
+ LiteralTestUtil::ExpectEqual(*round_tripped, *actual);
+}
+
+TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim1Minor) {
+ string data(sizeof(float) * 4, 0);
+ tensorflow::gtl::MutableArraySlice<float> floats(
+ tensorflow::bit_cast<float*>(data.data()), 4);
+ // With y as the minor dimension, these will become:
+ floats[0] = 42.0; // y=0,x=0
+ floats[1] = 24.0; // y=1,x=0
+ floats[2] = 64.0; // y=0,x=1
+ floats[3] = 46.0; // y=1,x=1
+
+ string fname =
+ tensorflow::testing::TmpDir() + "/RoundTripsR2F32Size2x2Dim1Minor.data";
+ EXPECT_TRUE(
+ tensorflow::WriteStringToFile(tensorflow::Env::Default(), fname, data)
+ .ok());
+
+ const Layout layout = LayoutUtil::MakeLayout({0, 1});
+
+ std::unique_ptr<tensorflow::RandomAccessFile> f;
+ TF_CHECK_OK(tensorflow::Env::Default()->NewRandomAccessFile(fname, &f));
+ PackedLiteralReader reader(f.release());
+ std::unique_ptr<Literal> actual =
+ reader.Read(ShapeUtil::MakeShape(F32, {2, 2}), &layout)
+ .ConsumeValueOrDie();
+ EXPECT_TRUE(reader.IsExhausted());
+
+ EXPECT_EQ(42.0f, LiteralUtil::Get<float>(*actual, {0, 0}));
+ EXPECT_EQ(24.0f, LiteralUtil::Get<float>(*actual, {1, 0}));
+ EXPECT_EQ(64.0f, LiteralUtil::Get<float>(*actual, {0, 1}));
+ EXPECT_EQ(46.0f, LiteralUtil::Get<float>(*actual, {1, 1}));
+
+ std::unique_ptr<Literal> round_tripped = RoundTripToServer(*actual);
+ LiteralTestUtil::ExpectEqual(*round_tripped, *actual);
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc b/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc
new file mode 100644
index 0000000000..04a8bab0eb
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc
@@ -0,0 +1,164 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Tests transferring literals of various shapes and values in and out of the
+// XLA service.
+
+#include <memory>
+#include <numeric>
+#include <vector>
+
+#include "tensorflow/compiler/xla/array4d.h"
+#include "tensorflow/compiler/xla/client/global_data.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+class RoundTripTransferTest : public ClientLibraryTestBase {
+ protected:
+ void RoundTripTest(const Literal& original) {
+ std::unique_ptr<GlobalData> data =
+ client_->TransferToServer(original).ConsumeValueOrDie();
+ std::unique_ptr<Literal> result =
+ client_->Transfer(*data).ConsumeValueOrDie();
+ LiteralTestUtil::ExpectEqual(original, *result);
+ }
+};
+
+TEST_F(RoundTripTransferTest, R0S32) {
+ RoundTripTest(*LiteralUtil::CreateR0<int32>(42));
+}
+
+TEST_F(RoundTripTransferTest, R0F32) {
+ RoundTripTest(*LiteralUtil::CreateR0<float>(42.0));
+}
+
+TEST_F(RoundTripTransferTest, R1F32_Len0) {
+ RoundTripTest(*LiteralUtil::CreateR1<float>({}));
+}
+
+TEST_F(RoundTripTransferTest, R1F32_Len2) {
+ RoundTripTest(*LiteralUtil::CreateR1<float>({42.0, 64.0}));
+}
+
+TEST_F(RoundTripTransferTest, R1F32_Len256) {
+ std::vector<float> values(256);
+ std::iota(values.begin(), values.end(), 1.0);
+ RoundTripTest(*LiteralUtil::CreateR1<float>(values));
+}
+
+TEST_F(RoundTripTransferTest, R1F32_Len1024) {
+ std::vector<float> values(1024);
+ std::iota(values.begin(), values.end(), 1.0);
+ RoundTripTest(*LiteralUtil::CreateR1<float>(values));
+}
+
+TEST_F(RoundTripTransferTest, R1F32_Len1025) {
+ std::vector<float> values(1025);
+ std::iota(values.begin(), values.end(), 1.0);
+ RoundTripTest(*LiteralUtil::CreateR1<float>(values));
+}
+
+TEST_F(RoundTripTransferTest, R1F32_Len4096) {
+ std::vector<float> values(4096);
+ std::iota(values.begin(), values.end(), 1.0);
+ RoundTripTest(*LiteralUtil::CreateR1<float>(values));
+}
+
+TEST_F(RoundTripTransferTest, R2F32_Len10x0) {
+ RoundTripTest(
+ *LiteralUtil::CreateR2FromArray2D<float>(Array2D<float>(10, 0)));
+}
+
+TEST_F(RoundTripTransferTest, R2F32_Len2x2) {
+ RoundTripTest(*LiteralUtil::CreateR2<float>({{42.0, 64.0}, {77.0, 88.0}}));
+}
+
+TEST_F(RoundTripTransferTest, R3F32) {
+ RoundTripTest(
+ *LiteralUtil::CreateR3<float>({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}},
+ {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}}));
+}
+
+TEST_F(RoundTripTransferTest, R4F32) {
+ RoundTripTest(*LiteralUtil::CreateR4<float>({{
+ {{10, 11, 12, 13}, {14, 15, 16, 17}},
+ {{18, 19, 20, 21}, {22, 23, 24, 25}},
+ {{26, 27, 28, 29}, {30, 31, 32, 33}},
+ }}));
+}
+
+TEST_F(RoundTripTransferTest, EmptyTuple) {
+ RoundTripTest(*LiteralUtil::MakeTuple({}));
+}
+
+TEST_F(RoundTripTransferTest, TupleOfR1F32) {
+ RoundTripTest(
+ *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({1, 2}).get(),
+ LiteralUtil::CreateR1<float>({3, 4}).get()}));
+}
+
+TEST_F(RoundTripTransferTest, TupleOfR1F32_Len0_Len2) {
+ RoundTripTest(
+ *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({}).get(),
+ LiteralUtil::CreateR1<float>({3, 4}).get()}));
+}
+
+TEST_F(RoundTripTransferTest, TupleOfR0F32AndR1S32) {
+ RoundTripTest(
+ *LiteralUtil::MakeTuple({LiteralUtil::CreateR0<float>(1.0).get(),
+ LiteralUtil::CreateR1<int>({2, 3}).get()}));
+}
+
+// The two tests below measure the cost of large data transfers.
+TEST_F(RoundTripTransferTest, R2F32_Large) {
+ RoundTripTest(*LiteralUtil::CreateR2F32Linspace(-1.0f, 1.0f, 512, 512));
+}
+
+TEST_F(RoundTripTransferTest, R4F32_Large) {
+ Array4D<float> array4d(2, 2, 256, 256);
+ array4d.FillWithMultiples(1.0f);
+ RoundTripTest(*LiteralUtil::CreateR4FromArray4D<float>(array4d));
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/scalar_computations_test.cc b/tensorflow/compiler/xla/tests/scalar_computations_test.cc
new file mode 100644
index 0000000000..bd9cae4d1d
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/scalar_computations_test.cc
@@ -0,0 +1,630 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <cmath>
+#include <limits>
+#include <memory>
+
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/global_data.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/legacy_flags/llvm_backend_flags.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+class ScalarComputationsTest : public ClientLibraryTestBase {
+ public:
+ ErrorSpec error_spec_{0.0001};
+
+ protected:
+ // A template for building and running a binary comparison test.
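+  // For example (illustrative):
+  //   TestCompare<int32>(2, 5, /*expected=*/true, &ComputationBuilder::Lt);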
+ template <typename NativeT>
+ void TestCompare(NativeT lhs, NativeT rhs, bool expected,
+ ComputationDataHandle (ComputationBuilder::*op)(
+ const ComputationDataHandle&,
+ const ComputationDataHandle&,
+ tensorflow::gtl::ArraySlice<int64>)) {
+ ComputationBuilder builder(client_, TestName());
+ ComputationDataHandle lhs_op = builder.ConstantR0<NativeT>(lhs);
+ ComputationDataHandle rhs_op = builder.ConstantR0<NativeT>(rhs);
+ ComputationDataHandle result = (builder.*op)(lhs_op, rhs_op, {});
+ ComputeAndCompareR0<bool>(&builder, expected, {});
+ }
+
+ template <typename NativeT>
+ void TestMinMax(NativeT lhs, NativeT rhs, NativeT expected,
+ ComputationDataHandle (ComputationBuilder::*op)(
+ const ComputationDataHandle&,
+ const ComputationDataHandle&,
+ tensorflow::gtl::ArraySlice<int64>)) {
+ ComputationBuilder builder(client_, TestName());
+ ComputationDataHandle lhs_op = builder.ConstantR0<NativeT>(lhs);
+ ComputationDataHandle rhs_op = builder.ConstantR0<NativeT>(rhs);
+ ComputationDataHandle result = (builder.*op)(lhs_op, rhs_op, {});
+ ComputeAndCompareR0<NativeT>(&builder, expected, {});
+ }
+};
+
+TEST_F(ScalarComputationsTest, NegateScalarF32) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Neg(builder.ConstantR0<float>(2.1f));
+
+ ComputeAndCompareR0<float>(&builder, -2.1f, {}, error_spec_);
+}
+
+TEST_F(ScalarComputationsTest, NegateScalarS32) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Neg(builder.ConstantR0<int32>(2));
+
+ ComputeAndCompareR0<int32>(&builder, -2, {});
+}
+
+TEST_F(ScalarComputationsTest, AddTwoScalarsF32) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Add(builder.ConstantR0<float>(2.1f), builder.ConstantR0<float>(5.5f));
+
+ ComputeAndCompareR0<float>(&builder, 7.6f, {}, error_spec_);
+}
+
+TEST_F(ScalarComputationsTest, AddTwoScalarsS32) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Add(builder.ConstantR0<int32>(2), builder.ConstantR0<int32>(5));
+
+ ComputeAndCompareR0<int32>(&builder, 7, {});
+}
+
+TEST_F(ScalarComputationsTest, AddTwoScalarsU32) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Add(builder.ConstantR0<uint32>(35), builder.ConstantR0<uint32>(57));
+
+ ComputeAndCompareR0<uint32>(&builder, 92, {});
+}
+
+XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsU8) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Add(builder.ConstantR0<uint8>(35), builder.ConstantR0<uint8>(57));
+
+ ComputeAndCompareR0<uint8>(&builder, 92, {});
+}
+
+XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsU64) {
+ ComputationBuilder builder(client_, TestName());
+ const uint64 a = static_cast<uint64>(1) << 63;
+ const uint64 b = a + 1;
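+  // a + b wraps around the uint64 range; unsigned overflow is well defined.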
+ builder.Add(builder.ConstantR0<uint64>(a), builder.ConstantR0<uint64>(b));
+
+ ComputeAndCompareR0<uint64>(&builder, a + b, {});
+}
+
+XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsS64) {
+ ComputationBuilder builder(client_, TestName());
+ const int64 a = static_cast<int64>(1) << 62;
+ const int64 b = a + 1;
+ builder.Add(builder.ConstantR0<int64>(a), builder.ConstantR0<int64>(b));
+
+ // a + b overflows int64, and signed overflow is undefined behavior in C++,
+ // so compute the expected wrapped value using unsigned arithmetic.
+ const int64 expected =
+ static_cast<int64>(static_cast<uint64>(a) + static_cast<uint64>(b));
+ ComputeAndCompareR0<int64>(&builder, expected, {});
+}
+
+XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsF64) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Add(builder.ConstantR0<double>(0.25),
+ builder.ConstantR0<double>(3.5));
+
+ ComputeAndCompareR0<double>(&builder, 3.75, {});
+}
+
+TEST_F(ScalarComputationsTest, SubtractTwoScalarsF32) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Sub(builder.ConstantR0<float>(2.1f), builder.ConstantR0<float>(5.5f));
+
+ ComputeAndCompareR0<float>(&builder, -3.4f, {}, error_spec_);
+}
+
+TEST_F(ScalarComputationsTest, SubtractTwoScalarsS32) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Sub(builder.ConstantR0<int32>(2), builder.ConstantR0<int32>(5));
+
+ ComputeAndCompareR0<int32>(&builder, -3, {});
+}
+
+TEST_F(ScalarComputationsTest, MulThreeScalarsF32) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Mul(builder.Mul(builder.ConstantR0<float>(2.1f),
+ builder.ConstantR0<float>(5.5f)),
+ builder.ConstantR0<float>(0.5f));
+
+ ComputeAndCompareR0<float>(&builder, 5.775f, {}, error_spec_);
+}
+
+TEST_F(ScalarComputationsTest, MulTwoScalarsS32) {
+ std::vector<int32> data = {0,
+ 1,
+ -1,
+ 1234,
+ 0x1a243514,
+ std::numeric_limits<int32>::max(),
+ std::numeric_limits<int32>::min()};
+
+ for (int32 x : data) {
+ for (int32 y : data) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Mul(builder.ConstantR0<int32>(x), builder.ConstantR0<int32>(y));
+
+ // Signed integer overflow is undefined behavior in C++. Convert the input
+ // integers to unsigned, perform the multiplication unsigned, and convert
+ // back.
+ int32 expected = static_cast<uint32>(x) * static_cast<uint32>(y);
+
+ ComputeAndCompareR0<int32>(&builder, expected, {});
+ }
+ }
+}
+
+TEST_F(ScalarComputationsTest, MulTwoScalarsU32) {
+ std::vector<uint32> data = {0, 1, 0xDEADBEEF, 1234,
+ 0x1a243514, 0xFFFFFFFF, 0x80808080};
+
+ for (uint32 x : data) {
+ for (uint32 y : data) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Mul(builder.ConstantR0<uint32>(x), builder.ConstantR0<uint32>(y));
+
+ uint32 expected = x * y;
+ ComputeAndCompareR0<uint32>(&builder, expected, {});
+ }
+ }
+}
+
+TEST_F(ScalarComputationsTest, MulThreeScalarsS32) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Mul(
+ builder.Mul(builder.ConstantR0<int32>(2), builder.ConstantR0<int32>(5)),
+ builder.ConstantR0<int32>(1));
+
+ ComputeAndCompareR0<int32>(&builder, 10, {});
+}
+
+TEST_F(ScalarComputationsTest, MulThreeScalarsF32Params) {
+ ComputationBuilder builder(client_, TestName());
+ std::unique_ptr<Literal> a_literal = LiteralUtil::CreateR0<float>(2.1f);
+ std::unique_ptr<Literal> b_literal = LiteralUtil::CreateR0<float>(5.5f);
+ std::unique_ptr<Literal> c_literal = LiteralUtil::CreateR0<float>(0.5f);
+
+ std::unique_ptr<GlobalData> a_data =
+ client_->TransferToServer(*a_literal).ConsumeValueOrDie();
+ std::unique_ptr<GlobalData> b_data =
+ client_->TransferToServer(*b_literal).ConsumeValueOrDie();
+ std::unique_ptr<GlobalData> c_data =
+ client_->TransferToServer(*c_literal).ConsumeValueOrDie();
+
+ ComputationDataHandle a = builder.Parameter(0, a_literal->shape(), "a");
+ ComputationDataHandle b = builder.Parameter(1, b_literal->shape(), "b");
+ ComputationDataHandle c = builder.Parameter(2, c_literal->shape(), "c");
+ builder.Mul(builder.Mul(a, b), c);
+
+ ComputeAndCompareR0<float>(&builder, 5.775f,
+ {a_data.get(), b_data.get(), c_data.get()},
+ error_spec_);
+}
+
+TEST_F(ScalarComputationsTest, DivideTwoScalarsF32) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Div(builder.ConstantR0<float>(5.0f), builder.ConstantR0<float>(2.5f));
+
+ ComputeAndCompareR0<float>(&builder, 2.0f, {}, error_spec_);
+}
+
+XLA_TEST_F(ScalarComputationsTest, RemTwoScalarsF32) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Rem(builder.ConstantR0<float>(2.5f), builder.ConstantR0<float>(5.0f));
+
+ ComputeAndCompareR0<float>(&builder, 2.5f, {}, error_spec_);
+}
+
+XLA_TEST_F(ScalarComputationsTest, DivideTwoScalarsS32) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Div(builder.ConstantR0<int32>(-5), builder.ConstantR0<int32>(2));
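+ // XLA integer division truncates toward zero, so -5 / 2 == -2.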
+
+ ComputeAndCompareR0<int32>(&builder, -2, {});
+}
+
+TEST_F(ScalarComputationsTest, RemainderTwoScalarsNegativeResultS32) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Rem(builder.ConstantR0<int32>(-5), builder.ConstantR0<int32>(2));
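+ // The remainder takes the sign of the dividend (C semantics): -5 % 2 == -1.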
+
+ ComputeAndCompareR0<int32>(&builder, -1, {});
+}
+
+TEST_F(ScalarComputationsTest, RemainderTwoScalarsIntMinS32) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Rem(builder.ConstantR0<int32>(INT_MIN),
+ builder.ConstantR0<int32>(7919));
+
+ ComputeAndCompareR0<int32>(&builder, -1309, {});
+}
+
+TEST_F(ScalarComputationsTest, RemainderTwoScalarsIntMinVsIntMaxS32) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Rem(builder.ConstantR0<int32>(INT_MIN),
+ builder.ConstantR0<int32>(INT_MAX));
+
+ ComputeAndCompareR0<int32>(&builder, -1, {});
+}
+
+TEST_F(ScalarComputationsTest, RemainderTwoScalarsPositiveResultS32) {
+ ComputationBuilder builder(client_, TestName());
+ auto x = builder.Parameter(0, ShapeUtil::MakeShape(S32, {}), "x");
+ builder.Rem(x, builder.ConstantR0<int32>(80000));
+
+ std::unique_ptr<Literal> literal = LiteralUtil::CreateR0<int32>(87919);
+ TF_ASSIGN_OR_ASSERT_OK(auto input_data, client_->TransferToServer(*literal));
+ ComputeAndCompareR0<int32>(&builder, 7919, {input_data.get()});
+}
+
+XLA_TEST_F(ScalarComputationsTest, DivideTwoScalarsU32) {
+ ComputationBuilder builder(client_, TestName());
+ // This verifies 0xFFFFFFFE / 2 = 0x7FFFFFFF. If XLA incorrectly treated U32
+ // as S32, it would output -2 / 2 = -1 (0xFFFFFFFF).
+ builder.Div(builder.ConstantR0<uint32>(0xFFFFFFFE),
+ builder.ConstantR0<uint32>(2));
+
+ ComputeAndCompareR0<uint32>(&builder, 0x7FFFFFFF, {});
+}
+
+TEST_F(ScalarComputationsTest, LogicalAnd) {
+ for (bool x : {false, true}) {
+ for (bool y : {false, true}) {
+ ComputationBuilder builder(client_, TestName());
+ builder.LogicalAnd(builder.ConstantR0<bool>(x),
+ builder.ConstantR0<bool>(y));
+
+ ComputeAndCompareR0<bool>(&builder, x && y, {});
+ }
+ }
+}
+
+TEST_F(ScalarComputationsTest, LogicalOr) {
+ for (bool x : {false, true}) {
+ for (bool y : {false, true}) {
+ ComputationBuilder builder(client_, TestName());
+ builder.LogicalOr(builder.ConstantR0<bool>(x),
+ builder.ConstantR0<bool>(y));
+
+ ComputeAndCompareR0<bool>(&builder, x || y, {});
+ }
+ }
+}
+
+TEST_F(ScalarComputationsTest, LogicalNot) {
+ for (bool x : {false, true}) {
+ ComputationBuilder builder(client_, TestName());
+ builder.LogicalNot(builder.ConstantR0<bool>(x));
+
+ ComputeAndCompareR0<bool>(&builder, !x, {});
+ }
+}
+
+TEST_F(ScalarComputationsTest, SelectScalarTrue) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Select(builder.ConstantR0<bool>(true), // The predicate.
+ builder.ConstantR0<float>(123.0f), // The value on true.
+ builder.ConstantR0<float>(42.0f)); // The value on false.
+
+ ComputeAndCompareR0<float>(&builder, 123.0f, {}, error_spec_);
+}
+
+TEST_F(ScalarComputationsTest, SelectScalarFalse) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Select(builder.ConstantR0<bool>(false), // The predicate.
+ builder.ConstantR0<float>(123.0f), // The value on true.
+ builder.ConstantR0<float>(42.0f)); // The value on false.
+
+ ComputeAndCompareR0<float>(&builder, 42.0f, {}, error_spec_);
+}
+
+// This test spells out explicitly what the templatized comparison tests
+// below do.
+TEST_F(ScalarComputationsTest, CompareGtScalar) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Gt(builder.ConstantR0<float>(2.0f), builder.ConstantR0<float>(1.0f));
+
+ ComputeAndCompareR0<bool>(&builder, true, {});
+}
+
+// S32 comparisons.
+TEST_F(ScalarComputationsTest, CompareEqS32Greater) {
+ TestCompare<int32>(2, 1, false, &ComputationBuilder::Eq);
+}
+TEST_F(ScalarComputationsTest, CompareEqS32Equal) {
+ TestCompare<int32>(3, 3, true, &ComputationBuilder::Eq);
+}
+
+TEST_F(ScalarComputationsTest, CompareNeS32) {
+ TestCompare<int32>(2, 1, true, &ComputationBuilder::Ne);
+}
+
+TEST_F(ScalarComputationsTest, CompareGeS32) {
+ TestCompare<int32>(2, 1, true, &ComputationBuilder::Ge);
+}
+
+TEST_F(ScalarComputationsTest, CompareGtS32) {
+ TestCompare<int32>(1, 5, false, &ComputationBuilder::Gt);
+}
+
+TEST_F(ScalarComputationsTest, CompareLeS32) {
+ TestCompare<int32>(2, 1, false, &ComputationBuilder::Le);
+}
+
+TEST_F(ScalarComputationsTest, CompareLtS32) {
+ TestCompare<int32>(9, 7, false, &ComputationBuilder::Lt);
+ TestCompare<int32>(std::numeric_limits<int32>::min(),
+ std::numeric_limits<int32>::max(), true,
+ &ComputationBuilder::Lt);
+}
+
+// U32 comparisons.
+TEST_F(ScalarComputationsTest, CompareEqU32False) {
+ TestCompare<uint32>(2, 1, false, &ComputationBuilder::Eq);
+}
+
+TEST_F(ScalarComputationsTest, CompareNeU32) {
+ TestCompare<uint32>(2, 1, true, &ComputationBuilder::Ne);
+}
+
+TEST_F(ScalarComputationsTest, CompareGeU32Greater) {
+ TestCompare<uint32>(2, 1, true, &ComputationBuilder::Ge);
+}
+
+TEST_F(ScalarComputationsTest, CompareGeU32Equal) {
+ TestCompare<uint32>(3, 3, true, &ComputationBuilder::Ge);
+}
+
+TEST_F(ScalarComputationsTest, CompareGtU32) {
+ TestCompare<uint32>(1, 5, false, &ComputationBuilder::Gt);
+ TestCompare<uint32>(5, 5, false, &ComputationBuilder::Gt);
+ TestCompare<uint32>(5, 1, true, &ComputationBuilder::Gt);
+}
+
+TEST_F(ScalarComputationsTest, CompareLeU32) {
+ TestCompare<uint32>(2, 1, false, &ComputationBuilder::Le);
+}
+
+TEST_F(ScalarComputationsTest, CompareLtU32) {
+ TestCompare<uint32>(9, 7, false, &ComputationBuilder::Lt);
+ TestCompare<uint32>(0, std::numeric_limits<uint32>::max(), true,
+ &ComputationBuilder::Lt);
+}
+
+// F32 comparisons.
+TEST_F(ScalarComputationsTest, CompareEqF32False) {
+ TestCompare<float>(2.0, 1.3, false, &ComputationBuilder::Eq);
+}
+
+TEST_F(ScalarComputationsTest, CompareNeF32) {
+ TestCompare<float>(2.0, 1.3, true, &ComputationBuilder::Ne);
+}
+
+TEST_F(ScalarComputationsTest, CompareGeF32Greater) {
+ TestCompare<float>(2.0, 1.9, true, &ComputationBuilder::Ge);
+}
+TEST_F(ScalarComputationsTest, CompareGeF32Equal) {
+ TestCompare<float>(3.5, 3.5, true, &ComputationBuilder::Ge);
+}
+
+TEST_F(ScalarComputationsTest, CompareGtF32) {
+ TestCompare<float>(1.0, 5.2, false, &ComputationBuilder::Gt);
+}
+
+TEST_F(ScalarComputationsTest, CompareLeF32) {
+ TestCompare<float>(2.0, 1.2, false, &ComputationBuilder::Le);
+}
+
+TEST_F(ScalarComputationsTest, CompareLtF32) {
+ TestCompare<float>(9.0, 7.2, false, &ComputationBuilder::Lt);
+}
+
+// F32 comparisons with exceptional values. The test names encode the
+// left/right operands at the end, and use Minf and Mzero for -inf and -0.0.
+TEST_F(ScalarComputationsTest, CompareLtF32MinfMzero) {
+ TestCompare<float>(-INFINITY, -0.0, true, &ComputationBuilder::Lt);
+}
+TEST_F(ScalarComputationsTest, CompareLtF32MzeroZero) {
+ // IEEE 754 defines -0.0 and 0.0 to compare equal.
+ TestCompare<float>(-0.0, 0.0, false, &ComputationBuilder::Lt);
+}
+TEST_F(ScalarComputationsTest, CompareLtF32ZeroInf) {
+ TestCompare<float>(0.0, INFINITY, true, &ComputationBuilder::Lt);
+}
+
+TEST_F(ScalarComputationsTest, CompareGeF32MinfMzero) {
+ TestCompare<float>(-INFINITY, -0.0, false, &ComputationBuilder::Ge);
+}
+TEST_F(ScalarComputationsTest, CompareGeF32MzeroZero) {
+ // IEEE 754 defines -0.0 and 0.0 to compare equal.
+ TestCompare<float>(-0.0, 0.0, true, &ComputationBuilder::Ge);
+}
+TEST_F(ScalarComputationsTest, CompareGeF32ZeroInf) {
+ TestCompare<float>(0.0, INFINITY, false, &ComputationBuilder::Ge);
+}
+
+TEST_F(ScalarComputationsTest, ExpScalar) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Exp(builder.ConstantR0<float>(2.0f));
+
+ ComputeAndCompareR0<float>(&builder, 7.3890562, {}, error_spec_);
+}
+
+TEST_F(ScalarComputationsTest, LogScalar) {
+ ComputationBuilder builder(client_, "log");
+ builder.Log(builder.ConstantR0<float>(2.0f));
+
+ ComputeAndCompareR0<float>(&builder, 0.6931471, {}, error_spec_);
+}
+
+TEST_F(ScalarComputationsTest, TanhScalar) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Tanh(builder.ConstantR0<float>(2.0f));
+
+ ComputeAndCompareR0<float>(&builder, 0.96402758, {}, error_spec_);
+}
+
+XLA_TEST_F(ScalarComputationsTest, TanhDoubleScalar) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Tanh(builder.ConstantR0<double>(2.0));
+
+ ComputeAndCompareR0<double>(&builder, 0.96402758, {}, error_spec_);
+}
+
+TEST_F(ScalarComputationsTest, PowScalar) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Pow(builder.ConstantR0<float>(2.0f), builder.ConstantR0<float>(3.0f));
+
+ ComputeAndCompareR0<float>(&builder, 8.0, {}, error_spec_);
+}
+
+TEST_F(ScalarComputationsTest, ClampScalarHigh) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Clamp(builder.ConstantR0<float>(2.0f), // The lower bound.
+ builder.ConstantR0<float>(5.0f), // The operand to be clamped.
+ builder.ConstantR0<float>(3.0f)); // The upper bound.
+
+ ComputeAndCompareR0<float>(&builder, 3.0, {}, error_spec_);
+}
+
+TEST_F(ScalarComputationsTest, ClampScalarMiddle) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Clamp(builder.ConstantR0<float>(2.0f), // The lower bound.
+ builder.ConstantR0<float>(2.5f), // The operand to be clamped.
+ builder.ConstantR0<float>(3.0f)); // The upper bound.
+
+ ComputeAndCompareR0<float>(&builder, 2.5, {}, error_spec_);
+}
+
+TEST_F(ScalarComputationsTest, ClampScalarLow) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Clamp(builder.ConstantR0<float>(2.0f), // The lower bound.
+ builder.ConstantR0<float>(-5.0f), // The operand to be clamped.
+ builder.ConstantR0<float>(3.0f)); // The upper bound.
+
+ ComputeAndCompareR0<float>(&builder, 2.0, {}, error_spec_);
+}
+
+TEST_F(ScalarComputationsTest, MinS32Above) {
+ TestMinMax<int32>(10, 3, 3, &ComputationBuilder::Min);
+}
+
+TEST_F(ScalarComputationsTest, MinS32Below) {
+ TestMinMax<int32>(-100, 3, -100, &ComputationBuilder::Min);
+}
+
+TEST_F(ScalarComputationsTest, MaxS32Above) {
+ TestMinMax<int32>(10, 3, 10, &ComputationBuilder::Max);
+}
+
+TEST_F(ScalarComputationsTest, MaxS32Below) {
+ TestMinMax<int32>(-100, 3, 3, &ComputationBuilder::Max);
+}
+
+TEST_F(ScalarComputationsTest, MinU32Above) {
+ const uint32 large = std::numeric_limits<int32>::max();
+ TestMinMax<uint32>(large, 3, 3, &ComputationBuilder::Min);
+}
+
+TEST_F(ScalarComputationsTest, MinU32Below) {
+ TestMinMax<uint32>(0, 5, 0, &ComputationBuilder::Min);
+}
+
+TEST_F(ScalarComputationsTest, MaxU32Above) {
+ const uint32 large = std::numeric_limits<int32>::max();
+ TestMinMax<uint32>(large, 3, large, &ComputationBuilder::Max);
+}
+
+TEST_F(ScalarComputationsTest, MaxU32Below) {
+ TestMinMax<uint32>(0, 5, 5, &ComputationBuilder::Max);
+}
+
+TEST_F(ScalarComputationsTest, MinF32Above) {
+ TestMinMax<float>(10.1f, 3.1f, 3.1f, &ComputationBuilder::Min);
+}
+
+TEST_F(ScalarComputationsTest, MinF32Below) {
+ TestMinMax<float>(-100.1f, 3.1f, -100.1f, &ComputationBuilder::Min);
+}
+
+TEST_F(ScalarComputationsTest, MaxF32Above) {
+ TestMinMax<float>(10.1f, 3.1f, 10.1f, &ComputationBuilder::Max);
+}
+
+TEST_F(ScalarComputationsTest, MaxF32Below) {
+ TestMinMax<float>(-100.1f, 3.1f, 3.1f, &ComputationBuilder::Max);
+}
+
+TEST_F(ScalarComputationsTest, ComplicatedArithmeticExpressionF32) {
+ // Compute the expression (1 * (3 - 1) * (7 + 0) - 4) / 20.
+ ComputationBuilder b(client_, TestName());
+ b.Div(
+ b.Sub(b.Mul(b.ConstantR0<float>(1),
+ b.Mul(b.Sub(b.ConstantR0<float>(3), b.ConstantR0<float>(1)),
+ b.Add(b.ConstantR0<float>(7), b.ConstantR0<float>(0)))),
+ b.ConstantR0<float>(4)),
+ b.ConstantR0<float>(20));
+
+ ComputeAndCompareR0<float>(&b, 0.5, {}, error_spec_);
+}
+
+TEST_F(ScalarComputationsTest, ComplicatedArithmeticExpressionS32) {
+ // Compute the expression 1 * (3 - 1) * (7 + 0) - 4.
+ ComputationBuilder b(client_, TestName());
+ b.Sub(b.Mul(b.ConstantR0<int32>(1),
+ b.Mul(b.Sub(b.ConstantR0<int32>(3), b.ConstantR0<int32>(1)),
+ b.Add(b.ConstantR0<int32>(7), b.ConstantR0<int32>(0)))),
+ b.ConstantR0<int32>(4));
+
+ ComputeAndCompareR0<int32>(&b, 10, {});
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendLlvmBackendFlags(&flag_list);
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/select_and_scatter_test.cc b/tensorflow/compiler/xla/tests/select_and_scatter_test.cc
new file mode 100644
index 0000000000..fb1effc8c4
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/select_and_scatter_test.cc
@@ -0,0 +1,395 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Tests the select-and-scatter XLA operation.
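+//
+// SelectAndScatter slides a window across the operand, uses the select
+// computation to pick one element from each window, and combines the
+// corresponding source value into that position of the output (which is
+// initialized to the init value) via the scatter computation.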
+
+#include <memory>
+#include <vector>
+
+#include "tensorflow/compiler/xla/array2d.h"
+#include "tensorflow/compiler/xla/client/computation.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/client/padding.h"
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/reference_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+class SelectAndScatterTest : public ClientLibraryTestBase {
+ public:
+ SelectAndScatterTest() : builder_(client_, TestName()) {
+ // Create scalar GE/ADD computations (in S32 and F32) and F32 MAX/MIN
+ // computations for use as the select and scatter ops in the tests below.
+ ge_s32_ = CreateScalarGeComputation(S32, &builder_);
+ add_s32_ = CreateScalarAddComputation(S32, &builder_);
+ ge_f32_ = CreateScalarGeComputation(F32, &builder_);
+ add_f32_ = CreateScalarAddComputation(F32, &builder_);
+ max_f32_ = CreateScalarMaxComputation(F32, &builder_);
+ min_f32_ = CreateScalarMinComputation(F32, &builder_);
+ }
+
+ ComputationBuilder builder_;
+ Computation ge_s32_;
+ Computation add_s32_;
+ Computation ge_f32_;
+ Computation add_f32_;
+ Computation max_f32_;
+ Computation min_f32_;
+};
+
+// Test for F32 1D array, with a zero-element input.
+XLA_TEST_F(SelectAndScatterTest, R1S0F32) {
+ const auto operand = builder_.ConstantR1<float>({});
+ const auto source = builder_.ConstantR1<float>({});
+ builder_.SelectAndScatter(operand, ge_f32_, /*window_dimensions=*/{3},
+ /*window_strides=*/{3}, Padding::kValid, source,
+ builder_.ConstantR0<float>(0.0f), add_f32_);
+ ComputeAndCompareR1<float>(&builder_, {}, {}, ErrorSpec(1e-7));
+}
+
+// Test for F32 1D array, when windows do not overlap.
+XLA_TEST_F(SelectAndScatterTest, R1F32) {
+ const auto operand =
+ builder_.ConstantR1<float>({1.f, 9.f, 3.f, 7.f, 5.f, 6.f});
+ const auto source = builder_.ConstantR1<float>({34.f, 42.f});
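+ // With window 3 and stride 3, the window maxima are 9 (index 1) and
+ // 7 (index 3), so the source values 34 and 42 scatter onto those positions.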
+ const std::vector<float> expected = {0.f, 34.f, 0.f, 42.f, 0.f, 0.f};
+ builder_.SelectAndScatter(operand, ge_f32_, /*window_dimensions=*/{3},
+ /*window_strides=*/{3}, Padding::kValid, source,
+ builder_.ConstantR0<float>(0.0f), add_f32_);
+ ComputeAndCompareR1<float>(&builder_, expected, {}, ErrorSpec(1e-7));
+}
+
+// Test for S32 1D array, when windows do not overlap and the init value is 1.
+XLA_TEST_F(SelectAndScatterTest, R1S32) {
+ const auto operand = builder_.ConstantR1<int32>({-1, 0, 6, 4, -4, 10});
+ const auto source = builder_.ConstantR1<int32>({-10, 20});
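+ // The window maxima are 6 (index 2) and 10 (index 5); the source values add
+ // onto the init value of 1 there, and every other position keeps 1.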
+ const std::vector<int32> expected = {1, 1, -9, 1, 1, 21};
+ builder_.SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{3},
+ /*window_strides=*/{3}, Padding::kValid, source,
+ builder_.ConstantR0<int32>(1), add_s32_);
+ ComputeAndCompareR1<int32>(&builder_, expected, {});
+}
+
+// Test for S32 1D array, when windows overlap with each other.
+XLA_TEST_F(SelectAndScatterTest, R1S32OverlappingWindow) {
+ const auto operand = builder_.ConstantR1<int32>({1, 9, 3, 7, 5, 6});
+ const auto source = builder_.ConstantR1<int32>({34, 42, 53, 19});
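+ // With stride 1 the windows overlap: indices 1 and 3 are each selected by
+ // two windows, so their scattered source values accumulate (34+42, 53+19).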
+ const std::vector<int32> expected = {0, 76, 0, 72, 0, 0};
+ builder_.SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{3},
+ /*window_strides=*/{1}, Padding::kValid, source,
+ builder_.ConstantR0<int32>(0), add_s32_);
+ ComputeAndCompareR1<int32>(&builder_, expected, {});
+}
+
+// Test for S32 2D array, when windows do not overlap.
+XLA_TEST_F(SelectAndScatterTest, R2S32) {
+ const auto operand =
+ builder_.ConstantR2<int32>({{7, 2, 5, 3, 10, 2}, {3, 8, 9, 3, 4, 2}});
+ const auto source = builder_.ConstantR2<int32>({{2, 6}});
+ Array2D<int32> expected({{0, 0, 0, 0, 6, 0}, {0, 0, 2, 0, 0, 0}});
+ builder_.SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{2, 3},
+ /*window_strides=*/{2, 3}, Padding::kValid, source,
+ builder_.ConstantR0<int32>(0), add_s32_);
+ ComputeAndCompareR2<int32>(&builder_, expected, {});
+}
+
+// Similar to SelectAndScatterTest.R2S32, but the input is transposed via a
+// reshape.
+XLA_TEST_F(SelectAndScatterTest, ReshapeR2S32) {
+ const auto operand = builder_.ConstantR2<int32>(
+ {{7, 3}, {2, 8}, {5, 9}, {3, 3}, {10, 4}, {2, 2}});
+ const auto reshape =
+ builder_.Reshape(operand, /*dimensions=*/{1, 0}, /*new_sizes=*/{2, 6});
+ const auto source = builder_.ConstantR2<int32>({{2, 6}});
+ Array2D<int32> expected({{0, 0, 0, 0, 6, 0}, {0, 0, 2, 0, 0, 0}});
+ builder_.SelectAndScatter(reshape, ge_s32_, /*window_dimensions=*/{2, 3},
+ /*window_strides=*/{2, 3}, Padding::kValid, source,
+ builder_.ConstantR0<int32>(0), add_s32_);
+ ComputeAndCompareR2<int32>(&builder_, expected, {});
+}
+
+// Test for S32 2D array, when windows overlap with each other.
+XLA_TEST_F(SelectAndScatterTest, R2S32OverlappingWindow) {
+ const auto operand =
+ builder_.ConstantR2<int32>({{7, 2, 5, 3, 8}, {3, 8, 9, 3, 4}});
+ const auto source = builder_.ConstantR2<int32>({{2, 6, 4}});
+ Array2D<int32> expected({{0, 0, 0, 0, 0}, {0, 0, 12, 0, 0}});
+ builder_.SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{2, 3},
+ /*window_strides=*/{1, 1}, Padding::kValid, source,
+ builder_.ConstantR0<int32>(0), add_s32_);
+ ComputeAndCompareR2<int32>(&builder_, expected, {});
+}
+
+// Test for S32 2D array when the padding is Padding::kSame.
+XLA_TEST_F(SelectAndScatterTest, R2S32SamePadding) {
+ const auto operand =
+ builder_.ConstantR2<int32>({{7, 2, 5, 3, 8}, {3, 8, 9, 3, 4}});
+ const auto source = builder_.ConstantR2<int32>({{2, 6, 4}});
+ Array2D<int32> expected({{0, 0, 0, 0, 4}, {0, 2, 6, 0, 0}});
+ builder_.SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{2, 2},
+ /*window_strides=*/{2, 2}, Padding::kSame, source,
+ builder_.ConstantR0<int32>(0), add_s32_);
+ ComputeAndCompareR2<int32>(&builder_, expected, {});
+}
+
+// Test for S32 2D array when the padding is Padding::kSame and windows
+// overlap with each other.
+XLA_TEST_F(SelectAndScatterTest, R2S32SamePaddingOverlappingWindow) {
+ const auto operand =
+ builder_.ConstantR2<int32>({{7, 2, 5, 3, 8}, {3, 8, 9, 3, 4}});
+ const auto source =
+ builder_.ConstantR2<int32>({{2, 6, 4, 7, 1}, {3, 5, 8, 9, 10}});
+ Array2D<int32> expected({{0, 0, 0, 0, 8}, {0, 5, 23, 0, 19}});
+ builder_.SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{2, 2},
+ /*window_strides=*/{1, 1}, Padding::kSame, source,
+ builder_.ConstantR0<int32>(0), add_s32_);
+ ComputeAndCompareR2<int32>(&builder_, expected, {});
+}
+
+XLA_TEST_F(SelectAndScatterTest, R2F32OverlappingR2Source) {
+ const auto operand = builder_.ConstantR2<float>(
+ {{1.5f, 2.5f, 1.5f}, {3.5f, 1.5f, 3.5f}, {4.5f, 2.5f, 4.5f}});
+ const auto source = builder_.ConstantR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}});
+ Array2D<float> expected(
+ {{0.0f, 0.0f, 0.0f}, {1.0f, 0.0f, 2.0f}, {3.0f, 0.0f, 4.0f}});
+ builder_.SelectAndScatter(operand, ge_f32_, /*window_dimensions=*/{2, 2},
+ /*window_strides=*/{1, 1}, Padding::kValid, source,
+ builder_.ConstantR0<float>(0.0f), add_f32_);
+ ComputeAndCompareR2<float>(&builder_, expected, {}, ErrorSpec(1e-7));
+}
+
+TEST_F(SelectAndScatterTest, R4F32Valid) {
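+ // pzo, pzs, and pze are 2D patterns for the operand, source, and expected
+ // arrays; FillWithPZ replicates each pattern across the two minor dims.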
+ Array2D<float> pzo = {{7.0f, 2.0f, 5.0f, 3.0f, 10.0f, 2.0f},
+ {3.0f, 8.0f, 9.0f, 3.0f, 4.00f, 2.0f},
+ {1.0f, 5.0f, 7.0f, 5.0f, 6.00f, 1.0f},
+ {0.0f, 6.0f, 2.0f, 7.0f, 2.00f, 8.0f}};
+ Array2D<float> pzs = {{2.0f, 6.0f}, {3.0f, 1.0f}};
+ Array2D<float> pze = {{0.0f, 0.0f, 0.0f, 0.0f, 6.0f, 0.0f},
+ {0.0f, 0.0f, 2.0f, 0.0f, 0.0f, 0.0f},
+ {0.0f, 0.0f, 3.0f, 0.0f, 0.0f, 0.0f},
+ {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f}};
+ Array4D<float> o(4, 6, 15, 220);
+ o.FillWithPZ(pzo);
+ auto operand = builder_.ConstantR4FromArray4D(o);
+ Array4D<float> e(4, 6, 15, 220);
+ e.FillWithPZ(pze);
+ Array4D<float> s(2, 2, 15, 220);
+ s.FillWithPZ(pzs);
+ auto source = builder_.ConstantR4FromArray4D(s);
+ builder_.SelectAndScatter(operand, ge_f32_, {2, 3, 1, 1}, {2, 3, 1, 1},
+ Padding::kValid, source,
+ builder_.ConstantR0<float>(0.0f), add_f32_);
+ ComputeAndCompareR4<float>(&builder_, e, {}, ErrorSpec(1e-7));
+}
+
+TEST_F(SelectAndScatterTest, R4F32Overlap) {
+ Array2D<float> pzo = {{7.0f, 2.0f, 5.0f, 3.0f, 8.0f},
+ {3.0f, 8.0f, 9.0f, 3.0f, 4.0f},
+ {1.0f, 5.0f, 7.0f, 5.0f, 6.0f},
+ {0.0f, 6.0f, 2.0f, 10.0f, 2.0f}};
+ Array2D<float> pzs = {{2.0f, 6.0f}, {3.0f, 1.0f}};
+ Array2D<float> pze = {{0.0f, 0.0f, 0.0f, 0.0f, 0.0f},
+ {0.0f, 0.0f, 8.0f, 0.0f, 0.0f},
+ {0.0f, 0.0f, 3.0f, 0.0f, 0.0f},
+ {0.0f, 0.0f, 0.0f, 1.0f, 0.0f}};
+ Array4D<float> o(4, 5, 17, 128);
+ o.FillWithPZ(pzo);
+ auto operand = builder_.ConstantR4FromArray4D(o);
+ Array4D<float> e(4, 5, 17, 128);
+ e.FillWithPZ(pze);
+ Array4D<float> s(2, 2, 17, 128);
+ s.FillWithPZ(pzs);
+ auto source = builder_.ConstantR4FromArray4D(s);
+ builder_.SelectAndScatter(operand, ge_f32_, {2, 3, 1, 1}, {2, 2, 1, 1},
+ Padding::kValid, source,
+ builder_.ConstantR0<float>(0.0f), add_f32_);
+ ComputeAndCompareR4<float>(&builder_, e, {}, ErrorSpec(1e-7));
+}
+
+TEST_F(SelectAndScatterTest, R4F32OverlapSmall) {
+ Array2D<float> pzo = {{7.0f, 2.0f, 5.0f, 3.0f, 8.0f},
+ {3.0f, 8.0f, 9.0f, 3.0f, 4.0f},
+ {1.0f, 5.0f, 7.0f, 5.0f, 6.0f},
+ {0.0f, 6.0f, 2.0f, 10.0f, 2.0f}};
+ Array2D<float> pzs = {{2.0f, 6.0f}, {3.0f, 1.0f}};
+ Array2D<float> pze = {{0.0f, 0.0f, 0.0f, 0.0f, 0.0f},
+ {0.0f, 0.0f, 8.0f, 0.0f, 0.0f},
+ {0.0f, 0.0f, 3.0f, 0.0f, 0.0f},
+ {0.0f, 0.0f, 0.0f, 1.0f, 0.0f}};
+ Array4D<float> o(4, 5, 1, 1);
+ o.FillWithPZ(pzo);
+ auto operand = builder_.ConstantR4FromArray4D(o);
+ Array4D<float> e(4, 5, 1, 1);
+ e.FillWithPZ(pze);
+ Array4D<float> s(2, 2, 1, 1);
+ s.FillWithPZ(pzs);
+ auto source = builder_.ConstantR4FromArray4D(s);
+ builder_.SelectAndScatter(operand, ge_f32_, {2, 3, 1, 1}, {2, 2, 1, 1},
+ Padding::kValid, source,
+ builder_.ConstantR0<float>(0.0f), add_f32_);
+ ComputeAndCompareR4<float>(&builder_, e, {}, ErrorSpec(1e-7));
+}
+
+TEST_F(SelectAndScatterTest, R4F32RefValidFixedSmall) {
+ // This test is testing the Reference Util
+ Array2D<float> pzo = {{7.0f, 2.0f, 5.0f, 3.0f, 10.0f, 2.0f},
+ {3.0f, 8.0f, 9.0f, 3.0f, 4.00f, 2.0f},
+ {1.0f, 5.0f, 7.0f, 5.0f, 6.00f, 1.0f},
+ {0.0f, 6.0f, 2.0f, 7.0f, 2.00f, 8.0f}};
+ Array2D<float> pzs = {{2.0f, 6.0f}, {3.0f, 1.0f}};
+ Array4D<float> o(4, 6, 4, 4);
+ o.FillWithPZ(pzo);
+ auto operand = builder_.ConstantR4FromArray4D(o);
+ Array4D<float> s(2, 2, 4, 4);
+ s.FillWithPZ(pzs);
+
+ auto source = builder_.ConstantR4FromArray4D(s);
+ builder_.SelectAndScatter(operand, ge_f32_, {2, 3, 1, 1}, {2, 3, 1, 1},
+ Padding::kValid, source,
+ builder_.ConstantR0<float>(0.0f), add_f32_);
+ auto e = ReferenceUtil::SelectAndScatter4DGePlus(o, s, 0.0f, {2, 3, 1, 1},
+ {2, 3, 1, 1}, false);
+ ComputeAndCompareR4<float>(&builder_, *e, {}, ErrorSpec(1e-7));
+}
+
+TEST_F(SelectAndScatterTest, R4F32RefSameRandom) {
+ Array4D<float> o(7, 7, 8, 256);
+ o.FillRandom(1.5f);
+ auto operand = builder_.ConstantR4FromArray4D(o);
+
+ Array4D<float> s(4, 4, 8, 256);
+ s.FillRandom(12.0f);
+ auto source = builder_.ConstantR4FromArray4D(s);
+
+ builder_.SelectAndScatter(operand, ge_f32_, {2, 2, 1, 1}, {2, 2, 1, 1},
+ Padding::kSame, source,
+ builder_.ConstantR0<float>(0.0f), add_f32_);
+ auto e = ReferenceUtil::SelectAndScatter4DGePlus(o, s, 0.0f, {2, 2, 1, 1},
+ {2, 2, 1, 1}, true);
+
+ ComputeAndCompareR4<float>(&builder_, *e, {}, ErrorSpec(1e-7));
+}
+
+TEST_F(SelectAndScatterTest, R4F32RefSameRandomFullyPadded) {
+ Array4D<float> o(1, 1, 5, 5);
+ o.FillRandom(1.5f);
+ auto operand = builder_.ConstantR4FromArray4D(o);
+
+ Array4D<float> s(1, 1, 5, 5);
+ s.FillRandom(12.0f);
+ auto source = builder_.ConstantR4FromArray4D(s);
+
+ builder_.SelectAndScatter(operand, ge_f32_, {3, 3, 1, 1}, {3, 3, 1, 1},
+ Padding::kSame, source,
+ builder_.ConstantR0<float>(0.0f), add_f32_);
+ auto e = ReferenceUtil::SelectAndScatter4DGePlus(o, s, 0.0f, {3, 3, 1, 1},
+ {3, 3, 1, 1}, true);
+
+ ComputeAndCompareR4<float>(&builder_, *e, {}, ErrorSpec(1e-7));
+}
+
+TEST_F(SelectAndScatterTest, R4F32RefValidRandom) {
+ Array4D<float> o(9, 9, 16, 128);
+ o.FillRandom(1.5f);
+ auto operand = builder_.ConstantR4FromArray4D(o);
+
+ Array4D<float> s(3, 3, 16, 128);
+ s.FillRandom(12.0f);
+ auto source = builder_.ConstantR4FromArray4D(s);
+
+ builder_.SelectAndScatter(operand, ge_f32_, {3, 3, 1, 1}, {3, 3, 1, 1},
+ Padding::kValid, source,
+ builder_.ConstantR0<float>(0.0f), add_f32_);
+
+ auto e = ReferenceUtil::SelectAndScatter4DGePlus(o, s, 0.0f, {3, 3, 1, 1},
+ {3, 3, 1, 1}, false);
+
+ ComputeAndCompareR4<float>(&builder_, *e, {}, ErrorSpec(1e-7));
+}
+
+TEST_F(SelectAndScatterTest, R4F32RefValidRandomSmall) {
+ Array4D<float> o(3, 3, 4, 4);
+ o.FillRandom(1.5f);
+ auto operand = builder_.ConstantR4FromArray4D(o);
+
+ Array4D<float> s(1, 1, 4, 4);
+ s.FillRandom(12.0f);
+ auto source = builder_.ConstantR4FromArray4D(s);
+
+ builder_.SelectAndScatter(operand, ge_f32_, {3, 3, 1, 1}, {3, 3, 1, 1},
+ Padding::kValid, source,
+ builder_.ConstantR0<float>(0.0f), add_f32_);
+
+ auto e = ReferenceUtil::SelectAndScatter4DGePlus(o, s, 0.0f, {3, 3, 1, 1},
+ {3, 3, 1, 1}, false);
+
+ ComputeAndCompareR4<float>(&builder_, *e, {}, ErrorSpec(1e-7));
+}
+
+XLA_TEST_F(SelectAndScatterTest, R1F32OverlappingWindowMaxScatter) {
+ const auto operand = builder_.ConstantR1<float>({1, 2, 3, 100, 3, 2, 1});
+ const auto source = builder_.ConstantR1<float>({34, 42, 53, 19});
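+ // Every length-4 window selects index 3 (the maximum, 100), so all four
+ // source values land there and combine via max, yielding 53.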
+ const std::vector<float> expected = {0, 0, 0, 53, 0, 0, 0};
+ builder_.SelectAndScatter(operand, ge_f32_, /*window_dimensions=*/{4},
+ /*window_strides=*/{1}, Padding::kValid, source,
+ builder_.ConstantR0<float>(0), max_f32_);
+ ComputeAndCompareR1<float>(&builder_, expected, {}, ErrorSpec(1e-7));
+}
+
+XLA_TEST_F(SelectAndScatterTest, R1F32OverlappingWindowMinScatter) {
+ const auto operand = builder_.ConstantR1<float>({1, 2, 3, 100, 3, 2, 1});
+ const auto source = builder_.ConstantR1<float>({34, 42, 53, 19});
+ const float max_float = std::numeric_limits<float>::max();
+ const std::vector<float> expected = {max_float, max_float, max_float, 19,
+ max_float, max_float, max_float};
+ builder_.SelectAndScatter(operand, ge_f32_, /*window_dimensions=*/{4},
+ /*window_strides=*/{1}, Padding::kValid, source,
+ builder_.ConstantR0<float>(max_float), min_f32_);
+ ComputeAndCompareR1<float>(&builder_, expected, {}, ErrorSpec(1e-7));
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/select_test.cc b/tensorflow/compiler/xla/tests/select_test.cc
new file mode 100644
index 0000000000..5ec9ac95fa
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/select_test.cc
@@ -0,0 +1,276 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+#include <vector>
+
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/global_data.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+class SelectTest : public ClientLibraryTestBase {
+ public:
+ ErrorSpec error_spec_{0.0001};
+};
+
+TEST_F(SelectTest, SelectScalarF32True) {
+ ComputationBuilder builder(client_, TestName());
+ auto pred = builder.ConstantR0<bool>(true);
+ auto on_true = builder.ConstantR0<float>(123.0f);
+ auto on_false = builder.ConstantR0<float>(42.0f);
+ auto result = builder.Select(pred, on_true, on_false);
+
+ ComputeAndCompareR0<float>(&builder, 123.0f, {}, error_spec_);
+}
+
+TEST_F(SelectTest, SelectScalarS32True) {
+ ComputationBuilder builder(client_, TestName());
+ auto pred = builder.ConstantR0<bool>(true);
+ auto on_true = builder.ConstantR0<int32>(-42);
+ auto on_false = builder.ConstantR0<int32>(42);
+ auto result = builder.Select(pred, on_true, on_false);
+
+ ComputeAndCompareR0<int32>(&builder, -42, {});
+}
+
+TEST_F(SelectTest, SelectScalarF32False) {
+ ComputationBuilder builder(client_, TestName());
+ auto pred = builder.ConstantR0<bool>(false);
+ auto on_true = builder.ConstantR0<float>(123.0f);
+ auto on_false = builder.ConstantR0<float>(42.0f);
+ auto result = builder.Select(pred, on_true, on_false);
+
+ ComputeAndCompareR0<float>(&builder, 42.0f, {}, error_spec_);
+}
+
+XLA_TEST_F(SelectTest, SelectR1S0F32WithConstantR1S0PRED) {
+ ComputationBuilder builder(client_, TestName());
+ auto pred = builder.ConstantR1<bool>({});
+ auto on_true = builder.ConstantR1<float>({});
+ auto on_false = builder.ConstantR1<float>({});
+ auto select = builder.Select(pred, on_true, on_false);
+
+ ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
+}
+
+TEST_F(SelectTest, SelectR1F32WithConstantR1PRED) {
+ ComputationBuilder builder(client_, TestName());
+ auto pred = builder.ConstantR1<bool>({false, true, false, true, false});
+ auto on_true = builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, -10.0f, 6.0f});
+ auto on_false = builder.ConstantR1<float>({10.0f, 5.0f, 1.0f, 10.0f, -6.0f});
+ auto select = builder.Select(pred, on_true, on_false);
+
+ ComputeAndCompareR1<float>(&builder, {10.0f, 25.5f, 1.0f, -10.0f, -6.0f}, {},
+ error_spec_);
+}
+
+XLA_TEST_F(SelectTest, SelectR1S0F32WithCmpR1S0S32s) {
+ // Similar to SelectR1S0F32WithConstantR1S0PRED, except that the pred vector
+ // is not a constant, but rather the result of comparing two other vectors.
+ ComputationBuilder builder(client_, TestName());
+ auto v1 = builder.ConstantR1<int32>({});
+ auto v2 = builder.ConstantR1<int32>({});
+ auto cmp = builder.Eq(v1, v2);
+ auto on_true = builder.ConstantR1<float>({});
+ auto on_false = builder.ConstantR1<float>({});
+ auto select = builder.Select(cmp, on_true, on_false);
+
+ ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
+}
+
+TEST_F(SelectTest, SelectR1F32WithCmpR1S32s) {
+ // Similar to SelectR1F32WithConstantR1PRED, except that the pred vector is
+ // not a constant, but rather the result of comparing two other vectors.
+ ComputationBuilder builder(client_, TestName());
+ auto v1 = builder.ConstantR1<int32>({1, 2, 3, 4, 5});
+ auto v2 = builder.ConstantR1<int32>({9, 2, 9, 4, 9});
+ auto cmp = builder.Eq(v1, v2);
+ auto on_true = builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, -10.0f, 6.0f});
+ auto on_false = builder.ConstantR1<float>({10.0f, 5.0f, 1.0f, 10.0f, -6.0f});
+ auto select = builder.Select(cmp, on_true, on_false);
+
+ ComputeAndCompareR1<float>(&builder, {10.0f, 25.5f, 1.0f, -10.0f, -6.0f}, {},
+ error_spec_);
+}
+
+TEST_F(SelectTest, SelectR1F32WithCmpR1F32s) {
+ // Similar to SelectR1F32WithCmpR1S32s, except "gt"-comparing two R1F32s.
+ ComputationBuilder builder(client_, TestName());
+ auto v1 = builder.ConstantR1<float>({1.0f, 2.0f, 3.0f, 4.0f, 5.0f});
+ auto v2 = builder.ConstantR1<float>({-1.0f, -2.0f, 13.0f, 14.0f, 4.4f});
+ auto cmp = builder.Gt(v1, v2);
+ auto on_true = builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, -10.0f, 6.0f});
+ auto on_false = builder.ConstantR1<float>({10.0f, 5.0f, 1.0f, 10.0f, -6.0f});
+ auto select = builder.Select(cmp, on_true, on_false);
+
+ ComputeAndCompareR1<float>(&builder, {-2.5f, 25.5f, 1.0f, 10.0f, 6.0f}, {},
+ error_spec_);
+}
+
+TEST_F(SelectTest, SelectR1F32WithCmpR1F32sFromParamsSmall) {
+ // Selects among two R1F32s, which come from parameters. v1 and v2 are
+ // compared, and selection between them happens based on a gt-comparison mask.
+ ComputationBuilder builder(client_, TestName());
+
+ ComputationDataHandle v1, v2;
+ std::unique_ptr<GlobalData> param0_data = CreateR1Parameter<float>(
+ {41.0f, 2.0f, 3.0f, 84.0f}, /*parameter_number=*/0, /*name=*/"v1",
+ /*builder=*/&builder, /*data_handle=*/&v1);
+ std::unique_ptr<GlobalData> param1_data = CreateR1Parameter<float>(
+ {21.0f, 22.0f, 23.0f, 24.0f}, /*parameter_number=*/1, /*name=*/"v2",
+ /*builder=*/&builder, /*data_handle=*/&v2);
+
+ auto cmp = builder.Gt(v1, v2);
+ auto select = builder.Select(cmp, v1, v2);
+ ComputeAndCompareR1<float>(&builder, {41.0f, 22.0f, 23.0f, 84.0f},
+ {param0_data.get(), param1_data.get()},
+ error_spec_);
+}
+
+TEST_F(SelectTest, SelectR1F32WithCmpR1F32sFromParamsLarge) {
+ // Similar to SelectR1F32WithCmpR1F32sFromParamsSmall, except that the
+ // data size passed in and out is large.
+ ComputationBuilder builder(client_, TestName());
+
+ // Number of floats in the data passed into and out of the computation.
+ constexpr int datalen = 15 * 1000;
+
+ // The inputs are initialized with a special pattern where in the first third
+ // of the data v1[i] > v2[i] and elsewhere it's vice versa.
+ std::vector<float> v1vec;
+ std::vector<float> v2vec;
+ std::vector<float> expected_vec;
+ for (int i = 0; i < datalen; ++i) {
+ float smaller = i;
+ float larger = i * 2;
+ if (i < datalen / 3) {
+ v1vec.push_back(larger);
+ v2vec.push_back(smaller);
+ } else {
+ v1vec.push_back(smaller);
+ v2vec.push_back(larger);
+ }
+ expected_vec.push_back(larger);
+ }
+
+ ComputationDataHandle v1, v2;
+ std::unique_ptr<GlobalData> param0_data =
+ CreateR1Parameter<float>(v1vec, /*parameter_number=*/0, /*name=*/"v1",
+ /*builder=*/&builder, /*data_handle=*/&v1);
+ std::unique_ptr<GlobalData> param1_data =
+ CreateR1Parameter<float>(v2vec, /*parameter_number=*/1, /*name=*/"v2",
+ /*builder=*/&builder, /*data_handle=*/&v2);
+
+ auto cmp = builder.Gt(v1, v2);
+ auto select = builder.Select(cmp, v1, v2);
+ ComputeAndCompareR1<float>(&builder, expected_vec,
+ {param0_data.get(), param1_data.get()},
+ error_spec_);
+}
+
+TEST_F(SelectTest, SelectR1F32WithCmpR1S32ToScalar) {
+ // "gt"-compares a R1S32 with a S32 scalar, and uses the resulting R1PRED to
+ // select between two R1F32s.
+ ComputationBuilder builder(client_, TestName());
+ auto v = builder.ConstantR1<int32>({1, -1, 2, -2});
+ auto s = builder.ConstantR0<int32>(0);
+ auto cmp = builder.Gt(v, s);
+
+ auto on_true = builder.ConstantR1<float>({11.0f, 22.0f, 33.0f, 44.0f});
+ auto on_false =
+ builder.ConstantR1<float>({-111.0f, -222.0f, -333.0f, -444.0f});
+ auto select = builder.Select(cmp, on_true, on_false);
+
+ ComputeAndCompareR1<float>(&builder, {11.0f, -222.0f, 33.0f, -444.0f}, {},
+ error_spec_);
+}
+
+TEST_F(SelectTest, SelectR1F32WithCmpR1F32ToScalar) {
+ // "gt"-compares a R1F32 with a F32 scalar, and uses the resulting R1PRED to
+ // select between two R1F32s.
+ ComputationBuilder builder(client_, TestName());
+ auto v = builder.ConstantR1<float>({1.0f, 2.0f, 3.0f, 4.0f});
+ auto s = builder.ConstantR0<float>(2.5f);
+ auto cmp = builder.Gt(v, s);
+
+ auto on_true = builder.ConstantR1<float>({11.0f, 22.0f, 33.0f, 44.0f});
+ auto on_false =
+ builder.ConstantR1<float>({-111.0f, -222.0f, -333.0f, -444.0f});
+ auto select = builder.Select(cmp, on_true, on_false);
+
+ ComputeAndCompareR1<float>(&builder, {-111.0f, -222.0f, 33.0f, 44.0f}, {},
+ error_spec_);
+}
+
+XLA_TEST_F(SelectTest, SelectR1S0F32WithScalarPredicate) {
+ for (bool which : {false, true}) {
+ ComputationBuilder builder(client_, TestName());
+ auto pred = builder.ConstantR0<bool>(which);
+ auto on_true = builder.ConstantR1<float>({});
+ auto on_false = builder.ConstantR1<float>({});
+ auto select = builder.Select(pred, on_true, on_false);
+
+ ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
+ }
+}
+
+TEST_F(SelectTest, SelectR1F32WithScalarPredicateTrue) {
+ ComputationBuilder builder(client_, TestName());
+ auto pred = builder.ConstantR0<bool>(true);
+ auto on_true = builder.ConstantR1<float>({-2.5f, 25.5f});
+ auto on_false = builder.ConstantR1<float>({10.0f, 5.0f});
+ auto select = builder.Select(pred, on_true, on_false);
+
+ ComputeAndCompareR1<float>(&builder, {-2.5f, 25.5f}, {}, error_spec_);
+}
+
+TEST_F(SelectTest, SelectR1F32WithScalarPredicateFalse) {
+ ComputationBuilder builder(client_, TestName());
+ auto pred = builder.ConstantR0<bool>(false);
+ auto on_true = builder.ConstantR1<float>({-2.5f, 25.5f});
+ auto on_false = builder.ConstantR1<float>({10.0f, 5.0f});
+ auto select = builder.Select(pred, on_true, on_false);
+
+ ComputeAndCompareR1<float>(&builder, {10.0f, 5.0f}, {}, error_spec_);
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/set_return_value_test.cc b/tensorflow/compiler/xla/tests/set_return_value_test.cc
new file mode 100644
index 0000000000..e15d744d95
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/set_return_value_test.cc
@@ -0,0 +1,116 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <vector>
+
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace {
+
+class SetReturnValueTest : public ClientLibraryTestBase {};
+
+TEST_F(SetReturnValueTest, NoSetValue) {
+ ComputationBuilder builder(client_, "no_set_value");
+ auto alpha = builder.ConstantR0<float>(1.0);
+ auto x = builder.ConstantR1<float>(
+ {-1.0, 1.0, 2.0, -2.0, -3.0, 3.0, 4.0, -4.0, -5.0, 5.0});
+ auto ax = builder.Add(alpha, x);
+ auto aax = builder.Add(alpha, ax);
+
+ std::vector<float> expected = {1.0, 3.0, 4.0, 0.0, -1.0,
+ 5.0, 6.0, -2.0, -3.0, 7.0};
+
+ ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+TEST_F(SetReturnValueTest, SetValue) {
+ ComputationBuilder builder(client_, "set_value");
+ auto alpha = builder.ConstantR0<float>(1.0);
+ auto x = builder.ConstantR1<float>(
+ {-1.0, 1.0, 2.0, -2.0, -3.0, 3.0, 4.0, -4.0, -5.0, 5.0});
+ auto ax = builder.Add(alpha, x);
+ auto aax = builder.Add(alpha, ax);
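+ // Override the default root (aax, the last instruction) and return ax.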
+ auto builder_status = builder.SetReturnValue(ax);
+ EXPECT_TRUE(builder_status.ok());
+
+ std::vector<float> expected = {0.0, 2.0, 3.0, -1.0, -2.0,
+ 4.0, 5.0, -3.0, -4.0, 6.0};
+
+ ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+TEST_F(SetReturnValueTest, SetValueAndModify) {
+ ComputationBuilder builder(client_, "set_value_and_modify");
+ auto alpha = builder.ConstantR0<float>(1.0);
+ auto x = builder.ConstantR1<float>(
+ {-1.0, 1.0, 2.0, -2.0, -3.0, 3.0, 4.0, -4.0, -5.0, 5.0});
+ auto ax = builder.Add(alpha, x);
+ auto aax = builder.Add(alpha, ax);
+ auto builder_status = builder.SetReturnValue(ax);
+ EXPECT_TRUE(builder_status.ok());
+ auto aaax = builder.Add(alpha, aax);
+
+ std::vector<float> expected = {0.0, 2.0, 3.0, -1.0, -2.0,
+ 4.0, 5.0, -3.0, -4.0, 6.0};
+
+ ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+TEST_F(SetReturnValueTest, SetValueMultipleTimesAndModify) {
+ ComputationBuilder builder(client_, "set_value_multiple_times_and_modify");
+ auto alpha = builder.ConstantR0<float>(1.0);
+ auto x = builder.ConstantR1<float>(
+ {-1.0, 1.0, 2.0, -2.0, -3.0, 3.0, 4.0, -4.0, -5.0, 5.0});
+ auto ax = builder.Add(alpha, x);
+ auto aax = builder.Add(alpha, ax);
+ auto builder_status = builder.SetReturnValue(aax);
+ EXPECT_TRUE(builder_status.ok());
+ auto aaax = builder.Add(alpha, aax);
+ builder_status = builder.SetReturnValue(ax);
+ EXPECT_TRUE(builder_status.ok());
+ auto aaaax = builder.Add(alpha, aaax);
+
+ std::vector<float> expected = {0.0, 2.0, 3.0, -1.0, -2.0,
+ 4.0, 5.0, -3.0, -4.0, 6.0};
+
+ ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/slice_test.cc b/tensorflow/compiler/xla/tests/slice_test.cc
new file mode 100644
index 0000000000..d63582fb98
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/slice_test.cc
@@ -0,0 +1,277 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Tests that slice operations can be performed.
+
+#include <numeric>
+#include <vector>
+
+#include "tensorflow/compiler/xla/array2d.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/reference_util.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+class SliceTest : public ClientLibraryTestBase {
+ protected:
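+ // Builds a constant R1 of {0, ..., 9} in NativeT, slices [2, 4), and
+ // expects the result {2, 3}.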
+ template <typename NativeT>
+ void RunSliceTenToTwo() {
+ std::vector<NativeT> constant;
+ for (int i = 0; i < 10; ++i) {
+ constant.push_back(static_cast<NativeT>(i));
+ }
+
+ ComputationBuilder builder(client_, TestName());
+ auto original = builder.ConstantR1<NativeT>(constant);
+ builder.Slice(original, {2}, {4});
+
+ const std::vector<NativeT> expected = {static_cast<NativeT>(2),
+ static_cast<NativeT>(3)};
+ ComputeAndCompareR1<NativeT>(&builder, expected, {});
+ }
+};
+
+XLA_TEST_F(SliceTest, SliceZeroToZeroF32) {
+ ComputationBuilder builder(client_, TestName());
+ auto original = builder.ConstantR1<float>({});
+ builder.Slice(original, {0}, {0});
+
+ ComputeAndCompareR1<float>(&builder, {}, {});
+}
+
+XLA_TEST_F(SliceTest, SliceTenToZeroF32) {
+ ComputationBuilder builder(client_, TestName());
+ std::vector<float> constant(10, 0.3);
+ auto original = builder.ConstantR1<float>(constant);
+ builder.Slice(original, {7}, {7});
+
+ ComputeAndCompareR1<float>(&builder, {}, {});
+}
+
+TEST_F(SliceTest, SliceTenToTwoF32) { RunSliceTenToTwo<float>(); }
+
+XLA_TEST_F(SliceTest, SliceTenToTwoF64) { RunSliceTenToTwo<double>(); }
+
+TEST_F(SliceTest, SliceTenToTwoU32) { RunSliceTenToTwo<uint32>(); }
+
+TEST_F(SliceTest, SliceTenToTwoS32) { RunSliceTenToTwo<int32>(); }
+
+XLA_TEST_F(SliceTest, SliceTenToTwoU64) { RunSliceTenToTwo<uint64>(); }
+
+XLA_TEST_F(SliceTest, SliceTenToTwoS64) { RunSliceTenToTwo<int64>(); }
+
+TEST_F(SliceTest, SliceTenToTen) {
+ const std::vector<float> values = {0.0, 1.0, 2.0, 3.0, 4.0,
+ 5.0, 6.0, 7.0, 8.0, 9.0};
+
+ ComputationBuilder builder(client_, TestName());
+ auto original = builder.ConstantR1<float>(values);
+ builder.Slice(original, {0}, {10});
+
+ ComputeAndCompareR1<float>(&builder, values, {}, ErrorSpec(0.000001));
+}
+
+TEST_F(SliceTest, SliceLastFourOf1024) {
+ std::vector<float> values(1024);
+ std::iota(values.begin(), values.end(), 0.0);
+
+ ComputationBuilder builder(client_, TestName());
+ auto original = builder.ConstantR1<float>(values);
+ builder.Slice(original, {1024 - 4}, {1024});
+
+ const std::vector<float> expected = {1020, 1021, 1022, 1023};
+ ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.000001));
+}
+
+// TODO(b/28491443): Fix wrong result on CPU and GPU. Failed on
+// 2016-05-01. Also b/28508652
+TEST_F(SliceTest, DISABLED_SliceUnaligned1024In4096Values) {
+ std::vector<float> values(4096);
+ std::iota(values.begin(), values.end(), 0.0);
+
+ ComputationBuilder builder(client_, TestName());
+ auto original = builder.ConstantR1<float>(values);
+ builder.Slice(original, {7}, {7 + 1024});
+
+ std::vector<float> expected(1024);
+ std::iota(expected.begin(), expected.end(), 7.0);
+ ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.000001));
+}
+
+XLA_TEST_F(SliceTest, Slice0x0to0x0F32) {
+ ComputationBuilder builder(client_, TestName());
+ auto original = builder.ConstantR2FromArray2D<float>(Array2D<float>(0, 0));
+ builder.Slice(original, {0, 0}, {0, 0});
+
+ ComputeAndCompareR2<float>(&builder, Array2D<float>(0, 0), {});
+}
+
+XLA_TEST_F(SliceTest, Slice0x20to0x5F32) {
+ ComputationBuilder builder(client_, TestName());
+ auto original = builder.ConstantR2FromArray2D<float>(Array2D<float>(0, 20));
+ builder.Slice(original, {0, 15}, {0, 20});
+
+ ComputeAndCompareR2<float>(&builder, Array2D<float>(0, 5), {});
+}
+
+XLA_TEST_F(SliceTest, Slice3x0to2x0F32) {
+ ComputationBuilder builder(client_, TestName());
+ auto original = builder.ConstantR2FromArray2D<float>(Array2D<float>(3, 0));
+ builder.Slice(original, {1, 0}, {3, 0});
+
+ ComputeAndCompareR2<float>(&builder, Array2D<float>(2, 0), {});
+}
+
+XLA_TEST_F(SliceTest, SliceQuadrantOf256x256) {
+ Array2D<float> values(256, 256);
+ for (int row = 0; row < 256; ++row) {
+ for (int col = 0; col < 256; ++col) {
+ values(row, col) = (row << 10) | col;
+ }
+ }
+
+ ComputationBuilder builder(client_, TestName());
+ auto original = builder.ConstantR2FromArray2D<float>(values);
+ builder.Slice(original, {128, 128}, {256, 256});
+
+ Array2D<float> expected(128, 128);
+ for (int row = 0; row < 128; ++row) {
+ for (int col = 0; col < 128; ++col) {
+ expected(row, col) = ((row + 128) << 10) | (col + 128);
+ }
+ }
+ ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.000001));
+}
+
+// Tests: (f32[1,4096], starts={0, 3072}, limits={1, 4096}) -> f32[1,1024]
+TEST_F(SliceTest, Slice_1x4096_To_1x1024) {
+ Array2D<float> values(1, 4096);
+ std::iota(values.data(), values.data() + 4096, 0.0);
+
+ ComputationBuilder builder(client_, TestName());
+ auto original = builder.ConstantR2FromArray2D<float>(values);
+ builder.Slice(original, {0, 3072}, {1, 4096});
+
+ Array2D<float> expected(1, 1024);
+ std::iota(expected.data(), expected.data() + 1024, 3072.0);
+ ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.000001));
+}
+
+// Tests slice: (f32[16,4], starts={0, 0}, limits={16, 2}) -> f32[16,2]
+TEST_F(SliceTest, Slice_16x4_To_16x2) {
+ Array2D<float> values(16, 4);
+ Array2D<float> expected(16, 2);
+ for (int row = 0; row < 16; ++row) {
+ for (int col = 0; col < 4; ++col) {
+ values(row, col) = (row << 10) | col;
+ if (col < 2) {
+ expected(row, col) = (row << 10) | col;
+ }
+ }
+ }
+ ComputationBuilder builder(client_, TestName());
+ auto original = builder.ConstantR2FromArray2D<float>(values);
+ builder.Slice(original, {0, 0}, {16, 2});
+ ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.000001));
+}
+
+// Tests: (f32[2, 2, 24, 256], starts={1, 0, 8, 0}, limits={2, 2, 16, 128})
+// -> f32[1, 2, 8, 128]
+TEST_F(SliceTest, SliceR4ThreeDimsMiddleMinor) {
+ Array4D<float> values(2, 2, 24, 256);
+ values.FillRandom(3.14f);
+ auto expected =
+ ReferenceUtil::Slice4D(values, {{1, 0, 8, 0}}, {{2, 2, 16, 128}});
+ ComputationBuilder builder(client_, TestName());
+ auto original = builder.ConstantR4FromArray4D(values);
+ builder.Slice(original, {1, 0, 8, 0}, {2, 2, 16, 128});
+ ComputeAndCompareR4(&builder, *expected, {}, ErrorSpec(0.000001));
+}
+
+struct R2Spec {
+ int64 input_dim0;
+ int64 input_dim1;
+ std::array<int64, 2> slice_starts;
+ std::array<int64, 2> slice_limits;
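+ // Note: the layout is recorded in the spec but is not yet consumed by the
+ // test body below.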
+ Layout layout;
+};
+
+// Parameterized test that generates patterned R2 values, slices them according
+// to the R2Spec, and compares the results with the ReferenceUtil version.
+class SliceR2Test : public ClientLibraryTestBase,
+ public ::testing::WithParamInterface<R2Spec> {};
+
+TEST_P(SliceR2Test, DoIt) {
+ const R2Spec& spec = GetParam();
+ Array2D<int32> input(spec.input_dim0, spec.input_dim1);
+ input.FillUnique();
+
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR2FromArray2D<int32>(input);
+ builder.Slice(a, spec.slice_starts, spec.slice_limits);
+
+ std::unique_ptr<Array2D<int32>> expected =
+ ReferenceUtil::Slice2D(input, spec.slice_starts, spec.slice_limits);
+ ComputeAndCompareR2<int32>(&builder, *expected, {});
+}
+
+// clang-format off
+INSTANTIATE_TEST_CASE_P(
+ SliceR2TestInstantiation, SliceR2Test,
+ ::testing::Values(
+ R2Spec {4, 12, {{0, 3}}, {{4, 6}}, LayoutUtil::MakeLayout({0, 1})},
+ R2Spec {4, 12, {{0, 3}}, {{4, 6}}, LayoutUtil::MakeLayout({1, 0})},
+ R2Spec {16, 4, {{0, 2}}, {{16, 4}}, LayoutUtil::MakeLayout({0, 1})},
+ R2Spec {16, 4, {{0, 2}}, {{16, 4}}, LayoutUtil::MakeLayout({1, 0})},
+ R2Spec {256, 400, {{0, 300}}, {{256, 400}},
+ LayoutUtil::MakeLayout({1, 0})},
+ R2Spec {500, 400, {{111, 123}}, {{300, 257}},
+ LayoutUtil::MakeLayout({1, 0})},
+ R2Spec {500, 400, {{111, 123}}, {{300, 400}},
+ LayoutUtil::MakeLayout({1, 0})},
+ R2Spec {384, 512, {{128, 256}}, {{256, 384}},
+ LayoutUtil::MakeLayout({1, 0})},
+ R2Spec {357, 512, {{111, 256}}, {{301, 384}},
+ LayoutUtil::MakeLayout({1, 0})}
+ )
+);
+// clang-format on
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/test_macros.h b/tensorflow/compiler/xla/tests/test_macros.h
new file mode 100644
index 0000000000..7f987a21ca
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/test_macros.h
@@ -0,0 +1,76 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Macros for use in enabling/disabling tests on particular
+// platforms. Marking a gunit test as disabled still ensures that it
+// compiles.
+//
+// Implementation note: the macros are structured as follows:
+// * Define the disabled macro to just pass the test name through (which, in
+// effect, does not disable it at all)
+// * If an XLA_TEST_BACKEND_$TARGET macro indicates we're compiling for the
+// $TARGET platform, make the disabled macro truly disable the test, i.e.
+// redefine the DISABLED_ON_$TARGET macro to prepend "DISABLED_" to the test
+// name.
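+//
+// For example (illustrative), when XLA_TEST_BACKEND_GPU is defined,
+//
+//   XLA_TEST_F(FooTest, DISABLED_ON_GPU(Bar))
+//
+// expands to TEST_F(FooTest, DISABLED_Bar), which gunit skips by default; on
+// the other backends it expands to TEST_F(FooTest, Bar) and runs normally.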
+
+#ifndef TENSORFLOW_COMPILER_XLA_TESTS_TEST_MACROS_H_
+#define TENSORFLOW_COMPILER_XLA_TESTS_TEST_MACROS_H_
+
+#include <string>
+
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/platform/test.h"
+
+// Use this macro instead of using TEST_P directly for parameterized tests;
+// otherwise DISABLED_ON_* macros nested in TEST_P will not get expanded, since
+// TEST_P stringifies its argument. That would disable the test for all targets
+// whenever any DISABLED_ON_* macro is used, and the test would just pass.
+// TODO(b/29122096): Remove this once TEST_P fixes this problem.
+#define XLA_TEST_P(test_case_name, test_name) TEST_P(test_case_name, test_name)
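+//
+// For example (illustrative), XLA_TEST_P(MyTest, DISABLED_ON_GPU(Case))
+// expands DISABLED_ON_GPU(Case) while it is still a macro argument, and only
+// then invokes TEST_P, so the test is disabled on the GPU backend alone.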
+
+#define DISABLED_ON_CPU(X) X
+#define DISABLED_ON_CPU_PARALLEL(X) X
+#define DISABLED_ON_GPU(X) X
+
+// We need this macro instead of pasting directly to support nesting
+// the DISABLED_ON_FOO macros, as in the definition of DISABLED_ON_CPU.
+// Otherwise the pasting is applied before macro expansion completes.
+#define XLA_TEST_PASTE(A, B) A##B
+
+// We turn off clang-format so we can indent the macros for readability.
+// clang-format off
+
+#ifdef XLA_TEST_BACKEND_CPU
+# undef DISABLED_ON_CPU
+# define DISABLED_ON_CPU(X) XLA_TEST_PASTE(DISABLED_, X)
+#endif // XLA_TEST_BACKEND_CPU
+
+#ifdef XLA_TEST_BACKEND_CPU_PARALLEL
+# undef DISABLED_ON_CPU
+# define DISABLED_ON_CPU(X) XLA_TEST_PASTE(DISABLED_, X)
+# undef DISABLED_ON_CPU_PARALLEL
+# define DISABLED_ON_CPU_PARALLEL(X) XLA_TEST_PASTE(DISABLED_, X)
+#endif // XLA_TEST_BACKEND_CPU_PARALLEL
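+
+// Note that when building for the CPU-parallel backend, DISABLED_ON_CPU is
+// redefined as well: a test disabled on the CPU backend is also disabled on
+// its parallel variant.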
+
+#ifdef XLA_TEST_BACKEND_GPU
+# undef DISABLED_ON_GPU
+# define DISABLED_ON_GPU(X) XLA_TEST_PASTE(DISABLED_, X)
+#endif // XLA_TEST_BACKEND_GPU
+
+// clang-format on
+
+#define XLA_TEST_F(test_fixture, test_name) TEST_F(test_fixture, test_name)
+
+#endif // TENSORFLOW_COMPILER_XLA_TESTS_TEST_MACROS_H_
diff --git a/tensorflow/compiler/xla/tests/test_utils.h b/tensorflow/compiler/xla/tests/test_utils.h
new file mode 100644
index 0000000000..6a23df4d3c
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/test_utils.h
@@ -0,0 +1,115 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_TESTS_TEST_UTILS_H_
+#define TENSORFLOW_COMPILER_XLA_TESTS_TEST_UTILS_H_
+
+#include <initializer_list>
+#include <memory>
+#include <random>
+
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace test_utils {
+
+// A class which generates pseudorandom numbers of a given type within a given
+// range. Not cryptographically secure and likely not perfectly evenly
+// distributed across the range but sufficient for most tests.
+template <typename NativeT>
+class PseudorandomGenerator {
+ public:
+ explicit PseudorandomGenerator(NativeT min_value, NativeT max_value,
+ uint32 seed)
+ : min_(min_value), max_(max_value), generator_(seed) {}
+
+ // Get a pseudorandom value.
+ NativeT get() {
+ std::uniform_real_distribution<> distribution;
+ return static_cast<NativeT>(min_ +
+ (max_ - min_) * distribution(generator_));
+ }
+
+ private:
+ NativeT min_;
+ NativeT max_;
+ std::mt19937 generator_;
+};
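+
+// Example usage (illustrative):
+//
+//   PseudorandomGenerator<float> gen(0.0f, 1.0f, /*seed=*/17);
+//   float v = gen.get();  // v lies in [0.0f, 1.0f).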
+
+// Convenience function for creating a rank-2 array with arbitrary layout.
+template <typename NativeT>
+std::unique_ptr<Literal> CreateR2LiteralWithLayout(
+ std::initializer_list<std::initializer_list<NativeT>> values,
+ tensorflow::gtl::ArraySlice<int64> minor_to_major) {
+ auto literal = MakeUnique<Literal>();
+ const int64 d0 = values.size();
+ const int64 d1 = values.begin()->size();
+ LiteralUtil::PopulateWithValue<NativeT>(0, {d0, d1}, literal.get());
+ *literal->mutable_shape()->mutable_layout() =
+ LayoutUtil::MakeLayout(minor_to_major);
+ TF_CHECK_OK(ShapeUtil::ValidateShape(literal->shape()));
+
+ int64 dim0 = 0;
+ for (auto inner_list : values) {
+ int64 dim1 = 0;
+ for (auto value : inner_list) {
+ LiteralUtil::Set(literal.get(), {dim0, dim1}, value);
+ ++dim1;
+ }
+ ++dim0;
+ }
+ return literal;
+}
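+
+// For example (illustrative), this creates a 2x2 literal whose dimension 0 is
+// minor, i.e. a column-major layout:
+//
+//   auto literal = CreateR2LiteralWithLayout<int32>(
+//       {{1, 2}, {3, 4}}, /*minor_to_major=*/{0, 1});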
+
+// Convenience function for creating a rank-3 array with arbitrary layout.
+template <typename NativeT>
+std::unique_ptr<Literal> CreateR3LiteralWithLayout(
+ std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
+ values,
+ tensorflow::gtl::ArraySlice<int64> minor_to_major) {
+ auto literal = MakeUnique<Literal>();
+ const int64 d0 = values.size();
+ const int64 d1 = values.begin()->size();
+ const int64 d2 = values.begin()->begin()->size();
+ LiteralUtil::PopulateWithValue<NativeT>(0, {d0, d1, d2}, literal.get());
+ *literal->mutable_shape()->mutable_layout() =
+ LayoutUtil::MakeLayout(minor_to_major);
+ TF_CHECK_OK(ShapeUtil::ValidateShape(literal->shape()));
+
+ int64 dim0 = 0;
+ for (auto inner_list : values) {
+ int64 dim1 = 0;
+ for (auto inner_inner_list : inner_list) {
+ int64 dim2 = 0;
+ for (auto value : inner_inner_list) {
+ LiteralUtil::Set(literal.get(), {dim0, dim1, dim2}, value);
+ ++dim2;
+ }
+ ++dim1;
+ }
+ ++dim0;
+ }
+ return literal;
+}
+
+} // namespace test_utils
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_TESTS_TEST_UTILS_H_
diff --git a/tensorflow/compiler/xla/tests/transpose_test.cc b/tensorflow/compiler/xla/tests/transpose_test.cc
new file mode 100644
index 0000000000..79f251bbc4
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/transpose_test.cc
@@ -0,0 +1,203 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+
+#include "tensorflow/compiler/xla/array2d.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/reference_util.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace {
+
+class TransposeTest : public ClientLibraryTestBase {
+ public:
+ ErrorSpec error_spec_{0.0001};
+
+ protected:
+ void TestTransposeConstant021(size_t n1, size_t n2, size_t n3);
+};
+
+XLA_TEST_F(TransposeTest, Transpose0x0) {
+ ComputationBuilder builder(client_, "Transpose");
+ auto lhs = builder.ConstantR2FromArray2D<float>(Array2D<float>(0, 0));
+ auto result = builder.Transpose(lhs, {1, 0});
+
+ ComputeAndCompareR2<float>(&builder, Array2D<float>(0, 0), {}, error_spec_);
+}
+
+XLA_TEST_F(TransposeTest, Transpose0x42) {
+ ComputationBuilder builder(client_, "Transpose");
+ auto lhs = builder.ConstantR2FromArray2D<float>(Array2D<float>(0, 42));
+ auto result = builder.Transpose(lhs, {1, 0});
+
+ ComputeAndCompareR2<float>(&builder, Array2D<float>(42, 0), {}, error_spec_);
+}
+
+XLA_TEST_F(TransposeTest, Transpose7x0) {
+ ComputationBuilder builder(client_, "Transpose");
+ auto lhs = builder.ConstantR2FromArray2D<float>(Array2D<float>(7, 0));
+ auto result = builder.Transpose(lhs, {1, 0});
+
+ ComputeAndCompareR2<float>(&builder, Array2D<float>(0, 7), {}, error_spec_);
+}
+
+TEST_F(TransposeTest, Transpose2x2) {
+ ComputationBuilder builder(client_, "Transpose");
+ auto lhs = builder.ConstantR2<float>({
+ {1.0, 2.0}, {3.0, 4.0},
+ });
+ auto result = builder.Transpose(lhs, {1, 0});
+
+ Array2D<float> expected({{1.0f, 3.0f}, {2.0f, 4.0f}});
+
+ ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
+}
+
+XLA_TEST_F(TransposeTest, Transpose0x2x3_2x3x0) {
+ ComputationBuilder builder(client_, "Transpose");
+ auto operand = builder.ConstantR3FromArray3D<int32>(Array3D<int32>(0, 2, 3));
+ auto result = builder.Transpose(operand, {1, 2, 0});
+
+ ComputeAndCompareR3<int32>(&builder, Array3D<int32>(2, 3, 0), {});
+}
+
+TEST_F(TransposeTest, Transpose1x2x3_2x3x1) {
+ ComputationBuilder builder(client_, "Transpose");
+ auto operand = builder.ConstantR3FromArray3D<int32>({{{1, 2, 3}, {4, 5, 6}}});
+ auto result = builder.Transpose(operand, {1, 2, 0});
+
+ Array3D<int32> expected({{{1}, {2}, {3}}, {{4}, {5}, {6}}});
+
+ ComputeAndCompareR3<int32>(&builder, expected, {});
+}
+
+TEST_F(TransposeTest, Transpose1x2x3_3x2x1) {
+ ComputationBuilder builder(client_, "Transpose");
+ auto operand = builder.ConstantR3FromArray3D<int32>({{{1, 2, 3}, {4, 5, 6}}});
+ auto result = builder.Transpose(operand, {2, 1, 0});
+
+ Array3D<int32> expected({{{1}, {4}}, {{2}, {5}}, {{3}, {6}}});
+
+ ComputeAndCompareR3<int32>(&builder, expected, {});
+}
+
+TEST_F(TransposeTest, Transpose1x2x3_1x2x3) {
+ ComputationBuilder builder(client_, "Transpose");
+ auto operand = builder.ConstantR3FromArray3D<int32>({{{1, 2, 3}, {4, 5, 6}}});
+ auto result = builder.Transpose(operand, {0, 1, 2});
+
+ Array3D<int32> expected({{{1, 2, 3}, {4, 5, 6}}});
+
+ ComputeAndCompareR3<int32>(&builder, expected, {});
+}
+
+TEST_F(TransposeTest, MultiTranspose3x2) {
+ Array2D<float> input({{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}});
+ Array2D<float> transposed({{1.0f, 3.0f, 5.0f}, {2.0f, 4.0f, 6.0f}});
+
+ for (int transposes = 0; transposes <= 10; ++transposes) {
+ ComputationBuilder builder(client_, "Transpose");
+ auto computed = builder.ConstantR2FromArray2D<float>(input);
+ for (int i = 0; i < transposes; ++i) {
+ computed = builder.Transpose(computed, {1, 0});
+ }
+ const Array2D<float>& expected = transposes % 2 == 0 ? input : transposed;
+ ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
+ }
+}
+
+// Test for transposing a [1x1] matrix.
+TEST_F(TransposeTest, Small_1x1) {
+ auto aoperand = MakeLinspaceArray2D(0.0, 1.0, 1, 1);
+
+ ComputationBuilder builder(client_, "transpose_1x1");
+ auto operand = builder.ConstantR2FromArray2D<float>(*aoperand);
+ builder.Transpose(operand, {1, 0});
+
+ auto expected = ReferenceUtil::TransposeArray2D(*aoperand);
+ ComputeAndCompareR2<float>(&builder, *expected, {}, ErrorSpec(1e-4));
+}
+
+// Test for transposing a [2x2] matrix.
+TEST_F(TransposeTest, Small_2x2) {
+ auto aoperand = MakeLinspaceArray2D(0.0, 4.0, 2, 2);
+
+ ComputationBuilder builder(client_, "transpose_2x2");
+ auto operand = builder.ConstantR2FromArray2D<float>(*aoperand);
+ builder.Transpose(operand, {1, 0});
+
+ auto expected = ReferenceUtil::TransposeArray2D(*aoperand);
+ ComputeAndCompareR2<float>(&builder, *expected, {}, ErrorSpec(1e-4));
+}
+
+void TransposeTest::TestTransposeConstant021(size_t n1, size_t n2, size_t n3) {
+ Array3D<int32> aoperand(n1, n2, n3);
+ Array3D<int32> expected(n1, n3, n2);
+ for (size_t i = 0; i < n1; ++i) {
+ for (size_t j = 0; j < n2; ++j) {
+ for (size_t k = 0; k < n3; ++k) {
+ aoperand(i, j, k) = i * n3 * n2 + j * n3 + k;
+ expected(i, k, j) = aoperand(i, j, k);
+ }
+ }
+ }
+
+ ComputationBuilder builder(client_, TestName());
+ auto operand = builder.ConstantR3FromArray3D(aoperand);
+ builder.Transpose(operand, {0, 2, 1});
+
+ ComputeAndCompareR3<int32>(&builder, expected, {});
+}
+
+TEST_F(TransposeTest, TransposeConstant021_SingleIncompleteTilePerLayer) {
+ TestTransposeConstant021(2, 2, 3);
+}
+
+TEST_F(TransposeTest, TransposeConstant021_SingleCompleteTilePerLayer) {
+ TestTransposeConstant021(2, 32, 32);
+}
+
+TEST_F(TransposeTest, TransposeConstant021_MultipleTilesPerLayer) {
+ TestTransposeConstant021(2, 70, 35);
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/tuple_test.cc b/tensorflow/compiler/xla/tests/tuple_test.cc
new file mode 100644
index 0000000000..cea9316a6d
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/tuple_test.cc
@@ -0,0 +1,415 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <initializer_list>
+#include <memory>
+
+#include "tensorflow/compiler/xla/array2d.h"
+#include "tensorflow/compiler/xla/client/computation.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace {
+
+class TupleTest : public ClientLibraryTestBase {
+ public:
+ ErrorSpec error_spec_{0.0001};
+};
+
+// Tests the creation of tuple data.
+XLA_TEST_F(TupleTest, TupleCreate) {
+ ComputationBuilder builder(client_, TestName());
+
+ const float constant_scalar = 7.3f;
+ std::initializer_list<float> constant_vector = {1.1f, 2.0f, 3.3f};
+ std::initializer_list<std::initializer_list<float>> constant_matrix = {
+ {1.1f, 2.2f, 3.5f}, // row 0
+ {4.8f, 5.0f, 6.7f}, // row 1
+ };
+ auto result = builder.Tuple({builder.ConstantR0<float>(constant_scalar),
+ builder.ConstantR1<float>(constant_vector),
+ builder.ConstantR2<float>(constant_matrix)});
+
+ auto expected = LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR0<float>(constant_scalar).get(),
+ LiteralUtil::CreateR1<float>(constant_vector).get(),
+ LiteralUtil::CreateR2<float>(constant_matrix).get()});
+ ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
+}
+
+// Tests the creation of tuple data with an entry that has zero elements.
+XLA_TEST_F(TupleTest, TupleCreateWithZeroElementEntry) {
+ ComputationBuilder builder(client_, TestName());
+
+ auto result = builder.Tuple(
+ {builder.ConstantR0<float>(7.0), builder.ConstantR1<float>({})});
+
+ auto expected =
+ LiteralUtil::MakeTuple({LiteralUtil::CreateR0<float>(7.0).get(),
+ LiteralUtil::CreateR1<float>({}).get()});
+ ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
+}
+
+// Tests the creation of an empty tuple.
+XLA_TEST_F(TupleTest, EmptyTupleCreate) {
+ ComputationBuilder builder(client_, TestName());
+ auto result = builder.Tuple({});
+ auto expected = LiteralUtil::MakeTuple({});
+ ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
+}
+
+// Trivial test for extracting a tuple element with GetTupleElement.
+XLA_TEST_F(TupleTest, GetTupleElement) {
+ ComputationBuilder builder(client_, TestName());
+ std::initializer_list<float> constant_vector = {1.f, 2.f, 3.f};
+ std::initializer_list<std::initializer_list<float>> constant_matrix = {
+ {1.f, 2.f, 3.f}, // row 0
+ {4.f, 5.f, 6.f}, // row 1
+ };
+ auto tuple_data = builder.Tuple({builder.ConstantR1<float>(constant_vector),
+ builder.ConstantR2<float>(constant_matrix)});
+ auto matrix_element = builder.GetTupleElement(tuple_data, 1);
+ ComputeAndCompareR2<float>(&builder, Array2D<float>(constant_matrix), {},
+ error_spec_);
+}
+
+// Trivial test for extracting a tuple element with GetTupleElement when the
+// element arrays have zero elements.
+XLA_TEST_F(TupleTest, GetTupleElementWithZeroElements) {
+ ComputationBuilder builder(client_, TestName());
+ auto tuple_data = builder.Tuple(
+ {builder.ConstantR1<float>({}),
+ builder.ConstantR2FromArray2D<float>(Array2D<float>(0, 101))});
+ auto matrix_element = builder.GetTupleElement(tuple_data, 1);
+ ComputeAndCompareR2<float>(&builder, Array2D<float>(0, 101), {}, error_spec_);
+}
+
+// Extracts both elements from a tuple with GetTupleElement and then adds them
+// together.
+XLA_TEST_F(TupleTest, AddTupleElements) {
+ ComputationBuilder builder(client_, TestName());
+ std::initializer_list<float> constant_vector = {1.f, 2.f, 3.f};
+ std::initializer_list<std::initializer_list<float>> constant_matrix = {
+ {1.f, 2.f, 3.f}, // row 0
+ {4.f, 5.f, 6.f}, // row 1
+ };
+ auto tuple_data = builder.Tuple({builder.ConstantR1<float>(constant_vector),
+ builder.ConstantR2<float>(constant_matrix)});
+ auto vector_element = builder.GetTupleElement(tuple_data, 0);
+ auto matrix_element = builder.GetTupleElement(tuple_data, 1);
+ auto vector_shape = builder.GetShape(vector_element).ConsumeValueOrDie();
+ auto matrix_shape = builder.GetShape(matrix_element).ConsumeValueOrDie();
+ auto result = builder.Add(matrix_element, vector_element,
+ /*broadcast_dimensions=*/{1});
+
+ Array2D<float> expected({
+ {2.f, 4.f, 6.f}, // row 0
+ {5.f, 7.f, 9.f}, // row 1
+ });
+ ASSERT_TRUE(ShapeUtil::ShapeIs(*vector_shape, F32, {3}));
+ ASSERT_TRUE(ShapeUtil::ShapeIs(*matrix_shape, F32, {/*y=*/2, /*x=*/3}));
+ ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
+}
+
+// Extracts both elements from a tuple and then puts them into a new tuple in
+// the opposite order.
+XLA_TEST_F(TupleTest, TupleGTEToTuple) {
+ ComputationBuilder builder(client_, TestName());
+ std::initializer_list<float> constant_vector = {1.f, 2.f, 3.f};
+ std::initializer_list<std::initializer_list<float>> constant_matrix = {
+ {1.f, 2.f, 3.f}, // row 0
+ {4.f, 5.f, 6.f}, // row 1
+ };
+ auto tuple_data = builder.Tuple({builder.ConstantR1<float>(constant_vector),
+ builder.ConstantR2<float>(constant_matrix)});
+ auto new_tuple = builder.Tuple({builder.GetTupleElement(tuple_data, 1),
+ builder.GetTupleElement(tuple_data, 0)});
+ auto expected = LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR2<float>(constant_matrix).get(),
+ LiteralUtil::CreateR1<float>(constant_vector).get()});
+ ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
+}
+
+// Builds two new tuples from an existing tuple (by means of GetTupleElement),
+// then adds up the components of the new tuples.
+XLA_TEST_F(TupleTest, TupleGTEToTupleToGTEAdd) {
+ //
+ // v------ --(GTE 0)-- --(GTE 0)----------
+ // \ / \ / \
+ // (tuple)-- (tuple01)-- \
+ // / | \ / \ \
+ // m------ | --(GTE 1)-- --(GTE 1)------------ \
+ // | \ \
+ // | (add)
+ // | / /
+ // |--------(GTE 1)-- --(GTE 0)------------ /
+ // \ \ / /
+ // \ (tuple10)-- /
+ // \ / \ /
+ // -----(GTE 0)-- --(GTE 1)----------
+ ComputationBuilder builder(client_, TestName());
+ std::initializer_list<float> constant_vector = {1.f, 2.f, 3.f};
+ std::initializer_list<std::initializer_list<float>> constant_matrix = {
+ {1.f, 2.f, 3.f}, // row 0
+ {4.f, 5.f, 6.f}, // row 1
+ };
+ auto tuple_data = builder.Tuple({builder.ConstantR1<float>(constant_vector),
+ builder.ConstantR2<float>(constant_matrix)});
+ auto new_tuple01 = builder.Tuple({builder.GetTupleElement(tuple_data, 0),
+ builder.GetTupleElement(tuple_data, 1)});
+ auto new_tuple10 = builder.Tuple({builder.GetTupleElement(tuple_data, 1),
+ builder.GetTupleElement(tuple_data, 0)});
+ auto vector_from_01 = builder.GetTupleElement(new_tuple01, 0);
+ auto vector_from_10 = builder.GetTupleElement(new_tuple10, 1);
+ auto matrix_from_01 = builder.GetTupleElement(new_tuple01, 1);
+ auto matrix_from_10 = builder.GetTupleElement(new_tuple10, 0);
+
+ auto addvectors = builder.Add(vector_from_01, vector_from_10);
+ auto addmatrices = builder.Add(matrix_from_01, matrix_from_10);
+
+ auto result = builder.Add(addmatrices, addvectors,
+ /*broadcast_dimensions=*/{1});
+
+ Array2D<float> expected({
+ {4.f, 8.f, 12.f}, // row 0
+ {10.f, 14.f, 18.f}, // row 1
+ });
+ ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
+}
+
+XLA_TEST_F(TupleTest, DISABLED_ON_CPU_PARALLEL(SelectBetweenTuplesOnFalse)) {
+ // Tests a selection between tuples with "false" path taken.
+ ComputationBuilder builder(client_, TestName());
+
+ std::initializer_list<float> vec1 = {1.f, 2.f, 3.f};
+ std::initializer_list<float> vec2 = {2.f, 4.f, 6.f};
+ auto tuple12 = builder.Tuple(
+ {builder.ConstantR1<float>(vec1), builder.ConstantR1<float>(vec2)});
+ auto tuple21 = builder.Tuple(
+ {builder.ConstantR1<float>(vec2), builder.ConstantR1<float>(vec1)});
+
+ auto select =
+ builder.Select(builder.ConstantR0<bool>(false), tuple12, tuple21);
+ auto expected =
+ LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>(vec2).get(),
+ LiteralUtil::CreateR1<float>(vec1).get()});
+ ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
+}
+
+XLA_TEST_F(TupleTest, TuplesInAMap) {
+ Computation tuple_computation;
+ {
+ // tuple_computation(x) = 100 * min(x, x^2) + max(x, x^2) using tuples.
+ //
+ // Need to put a select in there to prevent HLO-level optimizations from
+ // optimizing out the tuples.
+ ComputationBuilder b(client_, "sort_square");
+ auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
+ auto x2 = b.Mul(x, x);
+ auto x_smaller_tuple = b.Tuple({x, x2});
+ auto x2_smaller_tuple = b.Tuple({x2, x});
+ auto sorted = b.Select(b.Lt(x, x2), x_smaller_tuple, x2_smaller_tuple);
+ auto smaller = b.GetTupleElement(sorted, 0);
+ auto greater = b.GetTupleElement(sorted, 1);
+ b.Add(greater, b.Mul(b.ConstantR0<float>(100.0f), smaller));
+ auto computation_status = b.Build();
+ ASSERT_IS_OK(computation_status.status());
+ tuple_computation = computation_status.ConsumeValueOrDie();
+ }
+
+ ComputationBuilder b(client_, TestName());
+ auto input = b.ConstantR1<float>({-1.0f, 1.0f, 2.1f});
+ b.Map({input}, tuple_computation);
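+  // E.g. for x = -1: x^2 = 1, min = -1, max = 1, and 100 * -1 + 1 = -99.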
+ ComputeAndCompareR1<float>(&b, {-99.0f, 101.0f, 214.41f}, {}, error_spec_);
+}
+
+XLA_TEST_F(TupleTest, DISABLED_ON_CPU_PARALLEL(SelectBetweenTuplesOnTrue)) {
+ // Tests a selection between tuples with "true" path taken.
+ ComputationBuilder builder(client_, TestName());
+
+ std::initializer_list<float> vec1 = {1.f, 2.f, 3.f};
+ std::initializer_list<float> vec2 = {2.f, 4.f, 6.f};
+ auto tuple12 = builder.Tuple(
+ {builder.ConstantR1<float>(vec1), builder.ConstantR1<float>(vec2)});
+ auto tuple21 = builder.Tuple(
+ {builder.ConstantR1<float>(vec2), builder.ConstantR1<float>(vec1)});
+
+ auto select =
+ builder.Select(builder.ConstantR0<bool>(true), tuple12, tuple21);
+ auto expected =
+ LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>(vec1).get(),
+ LiteralUtil::CreateR1<float>(vec2).get()});
+ ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
+}
+
+XLA_TEST_F(TupleTest, SelectBetweenTuplesElementResult) {
+ // Tests a selection between tuples but the final result is an element of the
+ // tuple, not the whole tuple.
+ ComputationBuilder builder(client_, TestName());
+
+ std::initializer_list<float> vec1 = {1.f, 2.f, 3.f};
+ std::initializer_list<float> vec2 = {2.f, 4.f, 6.f};
+ auto tuple12 = builder.Tuple(
+ {builder.ConstantR1<float>(vec1), builder.ConstantR1<float>(vec2)});
+ auto tuple21 = builder.Tuple(
+ {builder.ConstantR1<float>(vec2), builder.ConstantR1<float>(vec1)});
+
+ auto select =
+ builder.Select(builder.ConstantR0<bool>(false), tuple12, tuple21);
+ auto element = builder.GetTupleElement(select, 0);
+
+ ComputeAndCompareR1<float>(&builder, vec2, {}, error_spec_);
+}
+
+// Cascaded selects between tuple types.
+XLA_TEST_F(TupleTest, DISABLED_ON_CPU_PARALLEL(SelectBetweenTuplesCascaded)) {
+ //
+ // vec1 vec2 vec2 vec1
+ // | | | |
+ // | | | |
+ // (tuple 12) (tuple 21)
+ // \ /
+ // \ /
+ // \ /
+ // true -- --(GTE 0)--(select 1)
+ // \ / |
+ // (pred tuple)-- | --(GTE 0)--
+ // / \ V / \
+ // false -- --(GTE 1)--(select 2)-- --(add)
+ // / \ /
+ // / --(GTE 1)--
+ // /
+ // (tuple 21)
+ ComputationBuilder builder(client_, TestName());
+
+ std::initializer_list<float> vec1 = {1.f, 2.f, 3.f};
+ std::initializer_list<float> vec2 = {2.f, 4.f, 6.f};
+
+ auto pred_tuple = builder.Tuple(
+ {builder.ConstantR0<bool>(true), builder.ConstantR0<bool>(false)});
+ auto tuple12 = builder.Tuple(
+ {builder.ConstantR1<float>(vec1), builder.ConstantR1<float>(vec2)});
+ auto tuple21 = builder.Tuple(
+ {builder.ConstantR1<float>(vec2), builder.ConstantR1<float>(vec1)});
+
+ auto select1 =
+ builder.Select(builder.GetTupleElement(pred_tuple, 0), tuple12, tuple21);
+ auto select2 =
+ builder.Select(builder.GetTupleElement(pred_tuple, 1), tuple21, select1);
+ auto result = builder.Add(builder.GetTupleElement(select2, 0),
+ builder.GetTupleElement(select2, 1));
+
+ ComputeAndCompareR1<float>(&builder, {3.f, 6.f, 9.f}, {}, error_spec_);
+}
+
+XLA_TEST_F(TupleTest,
+ DISABLED_ON_CPU_PARALLEL(SelectBetweenTuplesReuseConstants)) {
+ // Similar to SelectBetweenTuples, but the constants are shared between the
+ // input tuples.
+ ComputationBuilder builder(client_, TestName());
+
+ std::initializer_list<float> vec1 = {1.f, 2.f, 3.f};
+ std::initializer_list<float> vec2 = {2.f, 4.f, 6.f};
+ auto c1 = builder.ConstantR1<float>(vec1);
+ auto c2 = builder.ConstantR1<float>(vec2);
+ auto tuple12 = builder.Tuple({c1, c2});
+ auto tuple21 = builder.Tuple({c2, c1});
+
+ auto select =
+ builder.Select(builder.ConstantR0<bool>(false), tuple12, tuple21);
+ auto expected =
+ LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>(vec2).get(),
+ LiteralUtil::CreateR1<float>(vec1).get()});
+ ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
+}
+
+XLA_TEST_F(TupleTest, NestedTuples) {
+ ComputationBuilder builder(client_, TestName());
+ auto inner_tuple = builder.Tuple(
+ {builder.ConstantR1<float>({1.0, 2.0}), builder.ConstantR0<float>(42.0)});
+ auto outer_tuple =
+ builder.Tuple({inner_tuple, builder.ConstantR1<float>({22.0, 44.0})});
+
+ auto expected_v1 = LiteralUtil::CreateR1<float>({1.0, 2.0});
+ auto expected_s = LiteralUtil::CreateR0<float>(42.0);
+ auto expected_inner_tuple =
+ LiteralUtil::MakeTuple({expected_v1.get(), expected_s.get()});
+ auto expected_v2 = LiteralUtil::CreateR1<float>({22.0, 44.0});
+ auto expected =
+ LiteralUtil::MakeTuple({expected_inner_tuple.get(), expected_v2.get()});
+
+ ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
+}
+
+XLA_TEST_F(TupleTest, GetTupleElementOfNestedTuple) {
+ ComputationBuilder builder(client_, TestName());
+
+ Shape data_shape = ShapeUtil::MakeShape(F32, {3});
+ Shape inner_tuple_shape = ShapeUtil::MakeTupleShape({data_shape, data_shape});
+ Shape outer_tuple_shape =
+ ShapeUtil::MakeTupleShape({inner_tuple_shape, data_shape});
+
+ auto input = builder.Parameter(0, outer_tuple_shape, "input");
+ auto gte0 = builder.GetTupleElement(input, 0);
+ auto gte1 = builder.GetTupleElement(gte0, 1);
+ builder.Add(gte1, builder.ConstantR1<float>({10.0, 11.0, 12.0}));
+
+ std::unique_ptr<GlobalData> data =
+ client_
+ ->TransferToServer(*LiteralUtil::MakeTuple({
+ LiteralUtil::MakeTuple(
+ {
+ LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0}).get(),
+ LiteralUtil::CreateR1<float>({4.0, 5.0, 6.0}).get(),
+ })
+ .get(),
+ LiteralUtil::CreateR1<float>({7.0, 8.0, 9.0}).get(),
+ }))
+ .ConsumeValueOrDie();
+
+ std::vector<GlobalData*> arguments = {data.get()};
+ const std::vector<float> expected = {4.0 + 10.0, 5.0 + 11.0, 6.0 + 12.0};
+ ComputeAndCompareR1<float>(&builder, expected, arguments, ErrorSpec(1e-5));
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/unary_op_test.cc b/tensorflow/compiler/xla/tests/unary_op_test.cc
new file mode 100644
index 0000000000..fdbaa0d178
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/unary_op_test.cc
@@ -0,0 +1,179 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+#include <vector>
+
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/global_data.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+class UnaryOpTest : public ClientLibraryTestBase {
+ protected:
+ template <typename T>
+ T inf() {
+ return std::numeric_limits<T>::infinity();
+ }
+ template <typename T>
+ void AbsSize0TestHelper() {
+ ComputationBuilder builder(client_, TestName());
+ auto arg = builder.ConstantR1<T>({});
+ auto abs = builder.Abs(arg);
+
+ ComputeAndCompareR1<T>(&builder, {}, {});
+ }
+
+ template <typename T>
+ void AbsTestHelper() {
+ ComputationBuilder builder(client_, TestName());
+ auto arg = builder.ConstantR1<T>({-2, 25, 0, -123, inf<T>(), -inf<T>()});
+ auto abs = builder.Abs(arg);
+
+ ComputeAndCompareR1<T>(&builder, {2, 25, 0, 123, inf<T>(), inf<T>()}, {});
+ }
+
+ template <typename T>
+ void SignTestHelper() {
+ ComputationBuilder builder(client_, TestName());
+ auto arg = builder.ConstantR1<T>(
+ {-2, 25, 0, static_cast<T>(-0.0), -123, inf<T>(), -inf<T>()});
+ auto sign = builder.Sign(arg);
+
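+    // Note: the fourth input is -0.0; like 0, it is expected to map to 0.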
+ ComputeAndCompareR1<T>(&builder, {-1, 1, 0, 0, -1, 1, -1}, {});
+ }
+
+ template <typename T>
+ void SignAbsTestHelper() {
+ ComputationBuilder builder(client_, TestName());
+ auto arg = builder.ConstantR1<T>({-2, 25, 0, -123});
+ auto sign = builder.Sign(arg);
+ auto abs = builder.Abs(arg);
+ builder.Sub(builder.Mul(sign, abs), arg);
+
+ ComputeAndCompareR1<T>(&builder, {0, 0, 0, 0}, {});
+ }
+};
+
+template <>
+int UnaryOpTest::inf<int>() {
+ return 2147483647;
+}
+
+XLA_TEST_F(UnaryOpTest, AbsTestR1Size0) {
+ AbsSize0TestHelper<int>();
+ AbsSize0TestHelper<float>();
+}
+
+TEST_F(UnaryOpTest, AbsTestR1) {
+ AbsTestHelper<int>();
+ AbsTestHelper<float>();
+}
+
+TEST_F(UnaryOpTest, AbsTestR0) {
+ ComputationBuilder builder(client_, TestName());
+ auto argi = builder.ConstantR0<int>(-5);
+ auto absi = builder.Abs(argi);
+ auto argf = builder.ConstantR0<float>(-3.0f);
+ auto absf = builder.Abs(argf);
+ auto argf0 = builder.ConstantR0<float>(-0.0f);
+ auto absf0 = builder.Abs(argf0);
+ builder.Add(absf0, builder.Add(absf, builder.ConvertElementType(
+ absi, PrimitiveType::F32)));
+
+ ComputeAndCompareR0<float>(&builder, 8.0f, {});
+}
+
+TEST_F(UnaryOpTest, SignTestR0) {
+ ComputationBuilder builder(client_, TestName());
+ auto argi = builder.ConstantR0<int>(-5);
+ auto absi = builder.Sign(argi);
+ auto argf = builder.ConstantR0<float>(-4.0f);
+ auto absf = builder.Sign(argf);
+ auto argf0 = builder.ConstantR0<float>(-0.0f);
+ auto absf0 = builder.Sign(argf0);
+ builder.Add(absf0, builder.Add(absf, builder.ConvertElementType(
+ absi, PrimitiveType::F32)));
+
+ ComputeAndCompareR0<float>(&builder, -2.0f, {});
+}
+
+TEST_F(UnaryOpTest, SignTestR1) {
+ SignTestHelper<int>();
+ SignTestHelper<float>();
+}
+
+TEST_F(UnaryOpTest, SignAbsTestR1) {
+ SignAbsTestHelper<int>();
+ SignAbsTestHelper<float>();
+}
+
+TEST_F(UnaryOpTest, UnsignedAbsTestR1) {
+ ComputationBuilder builder(client_, TestName());
+ auto arg = builder.ConstantR1<unsigned int>(
+ {2, 25, 0, 123, std::numeric_limits<unsigned int>::max()});
+ auto abs = builder.Abs(arg);
+
+ ComputeAndCompareR1<unsigned int>(
+ &builder, {2, 25, 0, 123, std::numeric_limits<unsigned int>::max()}, {});
+}
+
+TEST_F(UnaryOpTest, UnsignedSignTestR1) {
+ ComputationBuilder builder(client_, TestName());
+ auto arg = builder.ConstantR1<unsigned int>(
+ {2, 25, 0, 123, std::numeric_limits<unsigned int>::max()});
+ auto sign = builder.Sign(arg);
+
+ ComputeAndCompareR1<unsigned int>(&builder, {1, 1, 0, 1, 1}, {});
+}
+
+TEST_F(UnaryOpTest, SignAbsTestR2) {
+ ComputationBuilder builder(client_, TestName());
+ auto arg = builder.ConstantR2<float>({{1.0, -2.0}, {-3.0, 4.0}});
+ auto sign = builder.Sign(arg);
+ auto abs = builder.Abs(arg);
+ builder.Sub(builder.Mul(sign, abs), arg);
+
+ ComputeAndCompareR2<float>(&builder, {{0, 0}, {0, 0}}, {});
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/vector_ops_reduce_test.cc b/tensorflow/compiler/xla/tests/vector_ops_reduce_test.cc
new file mode 100644
index 0000000000..7f3d7d9cb4
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/vector_ops_reduce_test.cc
@@ -0,0 +1,235 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+#include <numeric>
+#include <vector>
+
+#include "tensorflow/compiler/xla/array2d.h"
+#include "tensorflow/compiler/xla/array3d.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace {
+
+class VecOpsReduceTest : public ClientLibraryTestBase {
+ public:
+ VecOpsReduceTest() : builder_(client_, TestName()) {}
+
+ ComputationDataHandle BuildSampleConstantCube() {
+ // clang-format off
+ Array3D<float> x3d({
+ {{1.0, 2.0, 3.0}, // | dim 1 // } plane 0 in dim 0
+ {4.0, 5.0, 6.0}}, // V // }
+ // ---- dim 2 ---->
+ {{1.0, 2.0, 3.0}, // } plane 1 in dim 0
+ {4.0, 5.0, 6.0}},
+ {{1.0, 2.0, 3.0}, // } plane 2 in dim 0
+ {4.0, 5.0, 6.0}}});
+ // clang-format on
+ return builder_.ConstantR3FromArray3D<float>(x3d);
+ }
+
+ ComputationBuilder builder_;
+ ErrorSpec errspec_{1e-3, 0};
+};
+
+TEST_F(VecOpsReduceTest, AddReduceR1F32) {
+ auto sum_reducer = CreateScalarAddComputation(F32, &builder_);
+
+ auto x = builder_.ConstantR1<float>(
+ {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6});
+ auto add_reduce =
+ builder_.Reduce(x, builder_.ConstantR0<float>(0.0f), sum_reducer,
+ /*dimensions_to_reduce=*/{0});
+
+ ComputeAndCompareR0<float>(&builder_, -4.2f, {}, errspec_);
+}
+
+TEST_F(VecOpsReduceTest, AddReduceBigR1F32) {
+ auto sum_reducer = CreateScalarAddComputation(F32, &builder_);
+
+ std::vector<float> input(3000);
+ std::iota(input.begin(), input.end(), 100.0f);
+
+ auto x = builder_.ConstantR1<float>(input);
+ auto add_reduce =
+ builder_.Reduce(x, builder_.ConstantR0<float>(0.0f), sum_reducer,
+ /*dimensions_to_reduce=*/{0});
+
+ float expected = std::accumulate(input.begin(), input.end(), 0.0f);
+ ComputeAndCompareR0<float>(&builder_, expected, {}, errspec_);
+}
+
+TEST_F(VecOpsReduceTest, MaxReduceR1F32) {
+ auto max_reducer = CreateScalarMax();
+
+ auto x = builder_.ConstantR1<float>(
+ {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6});
+ auto max_reduce =
+ builder_.Reduce(x, builder_.ConstantR0<float>(0.0f), max_reducer,
+ /*dimensions_to_reduce=*/{0});
+
+ ComputeAndCompareR0<float>(&builder_, 2.6f, {}, errspec_);
+}
+
+TEST_F(VecOpsReduceTest, MaxReduceR1F32WithNontrivialInit) {
+ auto max_reducer = CreateScalarMax();
+
+ auto x = builder_.ConstantR1<float>(
+ {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6});
+ auto max_reduce =
+ builder_.Reduce(x, builder_.ConstantR0<float>(4.0f), max_reducer,
+ /*dimensions_to_reduce=*/{0});
+
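+  // The init value participates in the reduction, so the result is
+  // max(4.0f, max(x)) = 4.0f.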
+ ComputeAndCompareR0<float>(&builder_, 4.0f, {}, errspec_);
+}
+
+TEST_F(VecOpsReduceTest, AddReduceR2F32Dim1) {
+ auto sum_reducer = CreateScalarAddComputation(F32, &builder_);
+
+ // clang-format off
+ auto x = builder_.ConstantR2<float>({
+ {1.0, 2.0, 3.0}, // | dim 0
+ {4.0, 5.0, 6.0}}); // |
+ // ------ dim 1 ----------
+ // clang-format on
+
+ auto add_reduce =
+ builder_.Reduce(x, builder_.ConstantR0<float>(0.0f), sum_reducer,
+ /*dimensions_to_reduce=*/{1});
+
+ ComputeAndCompareR1<float>(&builder_, {6.0, 15.0}, {}, errspec_);
+}
+
+TEST_F(VecOpsReduceTest, AddReduceR2F32Dim0) {
+ auto sum_reducer = CreateScalarAddComputation(F32, &builder_);
+
+ // clang-format off
+ auto x = builder_.ConstantR2<float>({
+ {1.0, 2.0, 3.0},
+ {4.0, 5.0, 6.0}});
+ // clang-format on
+ auto add_reduce =
+ builder_.Reduce(x, builder_.ConstantR0<float>(0.0f), sum_reducer,
+ /*dimensions_to_reduce=*/{0});
+
+ ComputeAndCompareR1<float>(&builder_, {5.0, 7.0, 9.0}, {}, errspec_);
+}
+
+TEST_F(VecOpsReduceTest, AddReduceR3F32Dim2) {
+ auto sum_reducer = CreateScalarAddComputation(F32, &builder_);
+ auto x = BuildSampleConstantCube();
+ auto add_reduce =
+ builder_.Reduce(x, builder_.ConstantR0<float>(0.0f), sum_reducer,
+ /*dimensions_to_reduce=*/{2});
+
+ Array2D<float> expected_array({{6.0f, 15.0f}, {6.0f, 15.0f}, {6.0f, 15.0f}});
+
+ ComputeAndCompareR2<float>(&builder_, expected_array, {}, errspec_);
+}
+
+TEST_F(VecOpsReduceTest, AddReduceR3F32Dim1) {
+ auto sum_reducer = CreateScalarAddComputation(F32, &builder_);
+ auto x = BuildSampleConstantCube();
+ auto add_reduce =
+ builder_.Reduce(x, builder_.ConstantR0<float>(0.0f), sum_reducer,
+ /*dimensions_to_reduce=*/{1});
+
+ Array2D<float> expected_array(
+ {{5.0f, 7.0f, 9.0f}, {5.0f, 7.0f, 9.0f}, {5.0f, 7.0f, 9.0f}});
+
+ ComputeAndCompareR2<float>(&builder_, expected_array, {}, errspec_);
+}
+
+TEST_F(VecOpsReduceTest, AddReduceR3F32Dim0) {
+ auto sum_reducer = CreateScalarAddComputation(F32, &builder_);
+ auto x = BuildSampleConstantCube();
+ auto add_reduce =
+ builder_.Reduce(x, builder_.ConstantR0<float>(0.0f), sum_reducer,
+ /*dimensions_to_reduce=*/{0});
+
+ Array2D<float> expected_array({{3.0f, 6.0f, 9.0f}, {12.0f, 15.0f, 18.0f}});
+
+ ComputeAndCompareR2<float>(&builder_, expected_array, {}, errspec_);
+}
+
+TEST_F(VecOpsReduceTest, AddReduceR3F32Dims1and2) {
+ auto sum_reducer = CreateScalarAddComputation(F32, &builder_);
+ auto x = BuildSampleConstantCube();
+ auto add_reduce =
+ builder_.Reduce(x, builder_.ConstantR0<float>(0.0f), sum_reducer,
+ /*dimensions_to_reduce=*/{1, 2});
+
+ ComputeAndCompareR1<float>(&builder_, {21.0, 21.0, 21.0}, {}, errspec_);
+}
+
+XLA_TEST_F(VecOpsReduceTest, AddReduceR3F32Dims0and2) {
+ auto sum_reducer = CreateScalarAddComputation(F32, &builder_);
+ auto x = BuildSampleConstantCube();
+ auto add_reduce =
+ builder_.Reduce(x, builder_.ConstantR0<float>(0.0f), sum_reducer,
+ /*dimensions_to_reduce=*/{0, 2});
+
+ ComputeAndCompareR1<float>(&builder_, {18.0, 45.0}, {}, errspec_);
+}
+
+TEST_F(VecOpsReduceTest, AddReduceR3F32Dims0and1) {
+ auto sum_reducer = CreateScalarAddComputation(F32, &builder_);
+ auto x = BuildSampleConstantCube();
+ auto add_reduce =
+ builder_.Reduce(x, builder_.ConstantR0<float>(0.0f), sum_reducer,
+ /*dimensions_to_reduce=*/{0, 1});
+
+ ComputeAndCompareR1<float>(&builder_, {15.0, 21.0, 27.0}, {}, errspec_);
+}
+
+TEST_F(VecOpsReduceTest, AddReduceR3F32AllDims) {
+ auto sum_reducer = CreateScalarAddComputation(F32, &builder_);
+ auto x = BuildSampleConstantCube();
+ auto add_reduce =
+ builder_.Reduce(x, builder_.ConstantR0<float>(0.0f), sum_reducer,
+ /*dimensions_to_reduce=*/{0, 1, 2});
+
+ ComputeAndCompareR0<float>(&builder_, 63.0, {}, errspec_);
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc b/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc
new file mode 100644
index 0000000000..d9fc1e1e8f
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc
@@ -0,0 +1,423 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <cmath>
+#include <memory>
+#include <vector>
+
+#include "tensorflow/compiler/xla/array4d.h"
+#include "tensorflow/compiler/xla/client/computation.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/global_data.h"
+#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+class VecOpsSimpleTest : public ClientLibraryTestBase {
+ public:
+ explicit VecOpsSimpleTest(perftools::gputools::Platform* platform = nullptr)
+ : ClientLibraryTestBase(platform,
+ /*disabled_pass_names=*/{"algsimp", "inline"}) {}
+
+ ErrorSpec error_spec_{0.0001};
+};
+
+TEST_F(VecOpsSimpleTest, ExpTenValues) {
+ ComputationBuilder builder(client_, TestName());
+ auto x = builder.ConstantR1<float>(
+ {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6});
+ auto exp = builder.Exp(x);
+
+ std::vector<float> expected = {8.1662, 7.4274e-02, 13.4637, 1.8316e-02,
+ 8.1662, 9.9742, 6.7379e-03, 4.0657e-01,
+ 9.0718e-02, 4.9530};
+
+ ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
+}
+
+TEST_F(VecOpsSimpleTest, ExpManyValues) {
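+  // The element counts straddle powers of two, presumably to exercise any
+  // remainder handling in a vectorized Exp implementation.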
+ for (int count : {63, 64, 65, 127, 128, 129, 17 * 4096}) {
+ ComputationBuilder builder(client_, TestName());
+ std::vector<float> exponents;
+ for (int i = 0; i < count; ++i) {
+ exponents.push_back(i / static_cast<float>(count));
+ }
+ auto x = builder.ConstantR1<float>(exponents);
+ auto exp = builder.Exp(x);
+
+ std::vector<float> expected;
+ for (float exponent : exponents) {
+ expected.push_back(std::exp(exponent));
+ }
+
+ ComputeAndCompareR1<float>(&builder, expected, {},
+ ErrorSpec(/*aabs=*/1e-2, /*arel=*/1e-3));
+ }
+}
+
+TEST_F(VecOpsSimpleTest, ExpIn4D) {
+ ComputationBuilder builder(client_, TestName());
+ Array4D<float> exponents(2, 2, 2, 2);
+
+ std::vector<float> exponents_vector;
+ std::vector<float> expected_vector;
+ for (int i = 0; i < exponents.num_elements(); ++i) {
+ exponents_vector.push_back(static_cast<float>(i) /
+ exponents.num_elements());
+ expected_vector.push_back(std::exp(exponents_vector.back()));
+ }
+ exponents.SetValues(exponents_vector);
+
+ Array4D<float> expected(2, 2, 2, 2, expected_vector);
+
+ auto x = builder.ConstantR4FromArray4D<float>(exponents);
+ auto exp = builder.Exp(x);
+
+ ComputeAndCompareR4<float>(&builder, expected, {},
+ ErrorSpec(/*aabs=*/1e-2, /*arel=*/1e-3));
+}
+
+TEST_F(VecOpsSimpleTest, NegateTenFloatValues) {
+ ComputationBuilder builder(client_, TestName());
+ auto x = builder.ConstantR1<float>(
+ {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6});
+ builder.Neg(x);
+
+ std::vector<float> expected = {-2.1, 2.6, -2.6, 4.0, -2.1,
+ -2.3, 5.0, 0.9, 2.4, -1.6};
+ ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
+}
+
+TEST_F(VecOpsSimpleTest, NegateTenInt32Values) {
+ ComputationBuilder builder(client_, TestName());
+ auto x = builder.ConstantR1<int32>({2, -2, 12, -4, 5, 20, -15, 0, -2, 1});
+ builder.Neg(x);
+
+ std::vector<int> expected = {-2, 2, -12, 4, -5, -20, 15, 0, 2, -1};
+ ComputeAndCompareR1<int32>(&builder, expected, {});
+}
+
+TEST_F(VecOpsSimpleTest, NegateUint32Values) {
+ ComputationBuilder builder(client_, TestName());
+ auto x = builder.ConstantR1<uint32>(
+ {0, 1, 42, static_cast<uint32>(-1), static_cast<uint32>(-12)});
+ builder.Neg(x);
+ std::vector<uint32> expected = {0, static_cast<uint32>(-1),
+ static_cast<uint32>(-42), 1, 12};
+ ComputeAndCompareR1<uint32>(&builder, expected, {});
+}
+
+TEST_F(VecOpsSimpleTest, SquareTenValues) {
+ ComputationBuilder builder(client_, TestName());
+ auto x = builder.ConstantR1<float>(
+ {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6});
+ builder.SquareF32(x);
+
+ std::vector<float> expected = {4.41, 6.76, 6.76, 16., 4.41,
+ 5.29, 25., 0.81, 5.76, 2.56};
+ ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
+}
+
+TEST_F(VecOpsSimpleTest, ReciprocalTenValues) {
+ ComputationBuilder builder(client_, TestName());
+ auto x = builder.ConstantR1<float>(
+ {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6});
+ builder.ReciprocalF32(x);
+
+ std::vector<float> expected = {
+ 0.47619048, -0.38461538, 0.38461538, -0.25, 0.47619048,
+ 0.43478261, -0.2, -1.11111111, -0.41666667, 0.625};
+ ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
+}
+
+TEST_F(VecOpsSimpleTest, AddTenValuesViaMap) {
+ ComputationBuilder builder(client_, TestName());
+ auto add = CreateScalarAddComputation(F32, &builder);
+
+ auto x = builder.ConstantR1<float>(
+ {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6});
+ auto y = builder.ConstantR1<float>(
+ {-0.4, -0.6, -3.0, 0.2, 3.8, -2.2, -1.8, 4.9, 1.4, 0.6});
+  auto sum = builder.Map({x, y}, add);
+
+ std::vector<float> expected = {1.7, -3.2, -0.4, -3.8, 5.9,
+ 0.1, -6.8, 4., -1., 2.2};
+ ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
+}
+
+TEST_F(VecOpsSimpleTest, MaxTenValues) {
+ ComputationBuilder builder(client_, TestName());
+ auto x = builder.ConstantR1<float>(
+ {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6});
+ auto y = builder.ConstantR1<float>(
+ {-0.4, -0.6, -3.0, 0.2, 3.8, -2.2, -1.8, 4.9, 1.4, 0.6});
+ auto max = builder.Max(x, y);
+
+ std::vector<float> expected = {2.1, -0.6, 2.6, 0.2, 3.8,
+ 2.3, -1.8, 4.9, 1.4, 1.6};
+ ComputeAndCompareR1<float>(&builder, expected, {});
+}
+
+TEST_F(VecOpsSimpleTest, MaxTenValuesFromParams) {
+ // Similar to MaxTenValues, except that the inputs come from params rather
+ // than constants.
+ ComputationBuilder builder(client_, TestName());
+ ComputationDataHandle v1, v2;
+ std::unique_ptr<GlobalData> param0_data = CreateR1Parameter<float>(
+ {41.0f, 2.0f, 3.0f, 84.0f}, /*parameter_number=*/0, /*name=*/"v1",
+ /*builder=*/&builder, /*data_handle=*/&v1);
+ std::unique_ptr<GlobalData> param1_data = CreateR1Parameter<float>(
+ {21.0f, 22.0f, 23.0f, 24.0f}, /*parameter_number=*/1, /*name=*/"v2",
+ /*builder=*/&builder, /*data_handle=*/&v2);
+
+ auto max = builder.Max(v1, v2);
+ ComputeAndCompareR1<float>(&builder, {41.0f, 22.0f, 23.0f, 84.0f},
+ {param0_data.get(), param1_data.get()},
+ error_spec_);
+}
+
+TEST_F(VecOpsSimpleTest, Max15000ValuesFromParams) {
+ // Similar to MaxTenValuesFromParams, except that the data size passed in and
+ // out is large.
+ ComputationBuilder builder(client_, TestName());
+
+ // Number of floats in the data passed into and out of the computation.
+ constexpr int datalen = 15 * 1000;
+
+  // The inputs are initialized with a pattern where v1[i] > v2[i] in the
+  // first third of the data and v2[i] > v1[i] everywhere else.
+ std::vector<float> v1vec;
+ std::vector<float> v2vec;
+ std::vector<float> expected_vec;
+ for (int i = 0; i < datalen; ++i) {
+ float smaller = i;
+ float larger = i * 2;
+ if (i < datalen / 3) {
+ v1vec.push_back(larger);
+ v2vec.push_back(smaller);
+ } else {
+ v1vec.push_back(smaller);
+ v2vec.push_back(larger);
+ }
+ expected_vec.push_back(larger);
+ }
+
+ ComputationDataHandle v1, v2;
+ std::unique_ptr<GlobalData> param0_data =
+ CreateR1Parameter<float>(v1vec, /*parameter_number=*/0, /*name=*/"v1",
+ /*builder=*/&builder, /*data_handle=*/&v1);
+ std::unique_ptr<GlobalData> param1_data =
+ CreateR1Parameter<float>(v2vec, /*parameter_number=*/1, /*name=*/"v2",
+ /*builder=*/&builder, /*data_handle=*/&v2);
+
+ auto max = builder.Max(v1, v2);
+ ComputeAndCompareR1<float>(&builder, expected_vec,
+ {param0_data.get(), param1_data.get()},
+ error_spec_);
+}
+
+TEST_F(VecOpsSimpleTest, MaxTenValuesWithScalar) {
+ ComputationBuilder builder(client_, TestName());
+ auto x = builder.ConstantR1<float>(
+ {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6});
+ auto y = builder.ConstantR0<float>(0);
+ auto max = builder.Max(x, y);
+
+ std::vector<float> expected = {2.1, 0.0, 2.6, 0.0, 2.1,
+ 2.3, 0.0, 0.0, 0.0, 1.6};
+ ComputeAndCompareR1<float>(&builder, expected, {});
+}
+
+TEST_F(VecOpsSimpleTest, MinTenValues) {
+ ComputationBuilder builder(client_, TestName());
+ auto x = builder.ConstantR1<float>(
+ {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6});
+ auto y = builder.ConstantR1<float>(
+ {-0.4, -0.6, -3.0, 0.2, 3.8, -2.2, -1.8, 4.9, 1.4, 0.6});
+ auto min = builder.Min(x, y);
+
+ std::vector<float> expected = {-0.4, -2.6, -3.0, -4.0, 2.1,
+ -2.2, -5.0, -0.9, -2.4, 0.6};
+ ComputeAndCompareR1<float>(&builder, expected, {});
+}
+
+TEST_F(VecOpsSimpleTest, MinMaxTenValues) {
+ ComputationBuilder builder(client_, TestName());
+ auto zero = builder.ConstantR0<float>(0);
+ auto one = builder.ConstantR0<float>(1);
+ auto x = builder.ConstantR1<float>(
+ {2.1, -2.6, 2.6, 0.3, 3.1, 0.9, -5.0, 0.1, -2.4, 0.6});
+ auto clamp = builder.Min(builder.Max(x, zero), one);
+
+ std::vector<float> expected = {1.0, 0.0, 1.0, 0.3, 1.0,
+ 0.9, 0.0, 0.1, 0.0, 0.6};
+ ComputeAndCompareR1<float>(&builder, expected, {});
+}
+
+TEST_F(VecOpsSimpleTest, ClampTenValuesConstant) {
+ ComputationBuilder builder(client_, TestName());
+ auto zero = builder.ConstantR0<float>(0);
+ auto one = builder.ConstantR0<float>(1);
+ auto x = builder.ConstantR1<float>(
+ {2.1, -2.6, 2.6, 0.3, 3.1, 0.9, -5.0, 0.1, -2.4, 0.6});
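+ // Clamp(zero, x, one) computes the same result as the explicit
+ // Min(Max(x, zero), one) composition in MinMaxTenValues above.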
+ auto clamp = builder.Clamp(zero, x, one);
+
+ std::vector<float> expected = {1.0, 0.0, 1.0, 0.3, 1.0,
+ 0.9, 0.0, 0.1, 0.0, 0.6};
+ ComputeAndCompareR1<float>(&builder, expected, {});
+}
+
+TEST_F(VecOpsSimpleTest, ClampTwoValuesConstant) {
+ ComputationBuilder builder(client_, TestName());
+ auto zero = builder.ConstantR1<float>({0.0f, 0.0f});
+ auto one = builder.ConstantR1<float>({1.0f, 1.0f});
+ auto x = builder.ConstantR1<float>({2.1, -2.6});
+ auto clamp = builder.Clamp(zero, x, one);
+
+ std::vector<float> expected = {1.0, 0.0};
+ ComputeAndCompareR1<float>(&builder, expected, {});
+}
+
+TEST_F(VecOpsSimpleTest, ClampTenValuesConstantNonzeroLower) {
+ ComputationBuilder builder(client_, TestName());
+ auto one = builder.ConstantR0<float>(1);
+ auto two = builder.ConstantR0<float>(2);
+ auto x = builder.ConstantR1<float>(
+ {2.1, -2.6, 2.6, 0.3, 3.1, 0.9, -5.0, 0.1, -2.4, 0.6});
+ auto clamp = builder.Clamp(one, x, two);
+
+ std::vector<float> expected = {2.0, 1.0, 2.0, 1.0, 2.0,
+ 1.0, 1.0, 1.0, 1.0, 1.0};
+ ComputeAndCompareR1<float>(&builder, expected, {});
+}
+
+TEST_F(VecOpsSimpleTest, MapTenValues) {
+ Computation add_half;
+ {
+ // add_half(x) = x + 0.5
+ ComputationBuilder builder(client_, "add_half");
+ auto x_value =
+ builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x_value");
+ auto half = builder.ConstantR0<float>(0.5);
+ builder.Add(x_value, half);
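+ // The most recently added op becomes the root of the computation when
+ // Build() is called, so the Add's return value need not be kept.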
+ auto computation_status = builder.Build();
+ ASSERT_IS_OK(computation_status.status());
+ add_half = computation_status.ConsumeValueOrDie();
+ }
+
+ Computation clamp;
+ {
+ // clamp(y) = clamp<0,5>(y)
+ ComputationBuilder builder(client_, "clamp");
+ auto y_value =
+ builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "y_value");
+ auto zero = builder.ConstantR0<float>(0.0);
+ auto clamped = builder.Clamp(zero, y_value, builder.ConstantR0<float>(5));
+ auto computation_status = builder.Build();
+ ASSERT_IS_OK(computation_status.status());
+ clamp = computation_status.ConsumeValueOrDie();
+ }
+
+ Computation mult_relu_add;
+ {
+ // mult_relu_add(z) = clamp(add_half(2 * max(z, 0)))
+ ComputationBuilder builder(client_, "mult_relu_add");
+ auto z_value =
+ builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "z_value");
+ auto zero = builder.ConstantR0<float>(0.0);
+ auto two = builder.ConstantR0<float>(2.0);
+ auto max = builder.Max(z_value, zero);
+ auto mult = builder.Mul(two, max);
+ auto inner = builder.Map({mult}, add_half);
+ builder.Map({inner}, clamp);
+ auto computation_status = builder.Build();
+ ASSERT_IS_OK(computation_status.status());
+ mult_relu_add = computation_status.ConsumeValueOrDie();
+ }
+
+ ComputationBuilder builder(client_, "map10");
+ {
+ auto x = builder.ConstantR1<float>(
+ {2.1, -21.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6});
+ auto activations = builder.Map({x}, mult_relu_add);
+ }
+
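+ // Worked example: x = 2.1 maps to clamp<0,5>(2 * max(2.1, 0) + 0.5) = 4.7,
+ // while x = 2.6 maps to clamp<0,5>(2 * 2.6 + 0.5) = 5.0.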
+ std::vector<float> expected = {4.7, 0.5, 5.0, 0.5, 4.7,
+ 5.0, 0.5, 0.5, 0.5, 3.7};
+ ComputeAndCompareR1<float>(&builder, expected, {});
+}
+
+XLA_TEST_F(VecOpsSimpleTest, RemainderTenValuesS32) {
+ ComputationBuilder builder(client_, TestName());
+ auto x = builder.ConstantR1<int32>({-5, -4, -3, -2, -1, 0, 1, 2, 3, 4});
+ auto y = builder.ConstantR0<int32>(3);
+ builder.Rem(x, y);
+
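+ // Rem uses C-style truncated division, so each result takes the sign of
+ // its dividend: -5 rem 3 == -2, while 4 rem 3 == 1.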
+ std::vector<int32> expected = {-2, -1, 0, -2, -1, 0, 1, 2, 0, 1};
+ ComputeAndCompareR1<int32>(&builder, expected, {});
+}
+
+XLA_TEST_F(VecOpsSimpleTest, VectorPredicateEqual) {
+ ComputationBuilder builder(client_, TestName());
+ auto x = builder.ConstantR1<bool>({false, true});
+ auto y = builder.ConstantR1<bool>({true, false});
+ builder.Eq(x, y);
+
+ std::array<bool, 2> expected = {{false, false}};
+ ComputeAndCompareR1<bool>(&builder, expected, {});
+}
+
+XLA_TEST_F(VecOpsSimpleTest, VectorPredicateNotEqual) {
+ ComputationBuilder builder(client_, TestName());
+ auto x = builder.ConstantR1<bool>({false, true});
+ auto y = builder.ConstantR1<bool>({true, false});
+ builder.Ne(x, y);
+
+ std::array<bool, 2> expected = {{true, true}};
+ ComputeAndCompareR1<bool>(&builder, expected, {});
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc
new file mode 100644
index 0000000000..7820bc363d
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/while_test.cc
@@ -0,0 +1,395 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "tensorflow/compiler/xla/client/client_library.h"
+#include "tensorflow/compiler/xla/client/computation.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/service/platform_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace se = ::perftools::gputools;
+
+namespace xla {
+namespace {
+
+class WhileTest : public ClientLibraryTestBase {};
+
+// Tests a while node when the result type T is S32.
+//
+// int32 result = 0;
+// while (result < 5) {
+// result = result + 1;
+// }
+TEST_F(WhileTest, WhileWithScalarResult) {
+ auto result_shape = ShapeUtil::MakeShape(S32, {});
+
+ // Create a computation for the condition: repeat for 5 iterations.
+ Computation condition;
+ {
+ ComputationBuilder builder(client_, "condition");
+ auto prev = builder.Parameter(0, result_shape, "prev");
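+ // Note the operand order: Gt(5, prev) keeps iterating while prev < 5.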
+ builder.Gt(builder.ConstantR0<int32>(5), prev);
+ condition = builder.Build().ConsumeValueOrDie();
+ }
+
+ // Create a computation for the body: add 1 to the result variable.
+ Computation body;
+ {
+ ComputationBuilder builder(client_, "body");
+ auto prev = builder.Parameter(0, result_shape, "prev");
+ auto input = builder.ConstantR0<int32>(1);
+ auto result = builder.Add(input, prev);
+ body = builder.Build().ConsumeValueOrDie();
+ }
+
+ // Create a While node with computations for the condition and the body.
+ ComputationBuilder builder(client_, TestName());
+ auto init = builder.ConstantR0<int32>(0);
+ auto result = builder.While(condition, body, init);
+ auto shape = builder.GetShape(result).ConsumeValueOrDie();
+
+ ComputeAndCompareR0<int32>(&builder, 5, {});
+}
+
+// Tests a while node when the result type T is a vector.
+//
+// All constants are chosen to produce exact results.
+// vector<float> result(0);
+// while (result.sum() < 15.5f) {
+// result = result + vector<float>(0);
+// }
+// TODO(b/29185393): does not terminate on CPU.
+TEST_F(WhileTest, DISABLED_WhileWithEmptyVectorResult) {
+ Shape result_shape = ShapeUtil::MakeShape(F32, {0});
+
+ // Create a computation for the reduction.
+ Computation add;
+ {
+ ComputationBuilder builder(client_, "add");
+ auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
+ auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y");
+ builder.Add(x, y);
+ add = builder.Build().ConsumeValueOrDie();
+ }
+
+ // Create a computation for the condition.
+ // Repeat while the sum of the result vector is less than 15.5f.
+ Computation condition;
+ {
+ ComputationBuilder builder(client_, "condition");
+ auto prev = builder.Parameter(0, result_shape, "prev");
+ auto sum = builder.Reduce(prev, builder.ConstantR0<float>(0.0f), add,
+ /*dimensions_to_reduce=*/{0});
+ auto test = builder.Gt(builder.ConstantR0<float>(15.5f), sum);
+ condition = builder.Build().ConsumeValueOrDie();
+ }
+
+ // Create a computation for the body.
+ // Add a constant empty vector to the result vector. The sum of an empty
+ // vector stays at zero, so this loop can never terminate (see TODO above).
+ Computation body;
+ {
+ ComputationBuilder builder(client_, "body");
+ auto prev = builder.Parameter(0, result_shape, "prev");
+ auto input = builder.ConstantR1<float>({});
+ auto result = builder.Add(input, prev);
+ body = builder.Build().ConsumeValueOrDie();
+ }
+
+ // Create a While node with computations for the condition and the body.
+ ComputationBuilder builder(client_, "while");
+ auto init = builder.ConstantR1<float>({});
+ auto result = builder.While(condition, body, init);
+ VLOG(2) << "while = " << ShapeUtil::HumanString(
+ *builder.GetShape(result).ConsumeValueOrDie());
+
+ ComputeAndCompareR1<float>(&builder, {}, {}, ErrorSpec(0.0001));
+}
+
+// Tests a while node when the result type T is a vector.
+//
+// All constants are chosen to produce exact results.
+// vector<float> result(8, 0.0f);
+// while (result.sum() < 15.5f) {
+// result = result + vector<float>(8, 0.125f);
+// }
+TEST_F(WhileTest, WhileWithVectorResult) {
+ Shape result_shape = ShapeUtil::MakeShape(F32, {8});
+
+ // Create a computation for the reduction.
+ Computation add;
+ {
+ ComputationBuilder builder(client_, "add");
+ auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
+ auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y");
+ builder.Add(x, y);
+ add = builder.Build().ConsumeValueOrDie();
+ }
+
+ // Create a computation for the condition.
+ // Repeat while the sum of the result vector is less than 15.5f.
+ Computation condition;
+ {
+ ComputationBuilder builder(client_, "condition");
+ auto prev = builder.Parameter(0, result_shape, "prev");
+ auto sum = builder.Reduce(prev, builder.ConstantR0<float>(0.0f), add,
+ /*dimensions_to_reduce=*/{0});
+ auto test = builder.Gt(builder.ConstantR0<float>(15.5f), sum);
+ condition = builder.Build().ConsumeValueOrDie();
+ }
+
+ // Create a computation for the body.
+ // Add a constant vector of 0.125f to the result vector.
+ Computation body;
+ {
+ ComputationBuilder builder(client_, "body");
+ auto prev = builder.Parameter(0, result_shape, "prev");
+ auto input = builder.ConstantR1<float>(8, 0.125f);
+ auto result = builder.Add(input, prev);
+ body = builder.Build().ConsumeValueOrDie();
+ }
+
+ // Create a While node with computations for the condition and the body.
+ ComputationBuilder builder(client_, "while");
+ auto init = builder.ConstantR1<float>(8, 0.f);
+ auto result = builder.While(condition, body, init);
+ VLOG(2) << "while = " << ShapeUtil::HumanString(
+ *builder.GetShape(result).ConsumeValueOrDie());
+
+ // Individual elements will increase by 1/8 each time through the loop, so
+ // the sum of the eight elements will increase by 1.0 per iteration. It
+ // first reaches 16.0 > 15.5 after 16 iterations, when the elements have
+ // all reached 2.0.
+ std::vector<float> expected = {2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f};
+ ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+// Tests a while node when the result type T is a Tuple.
+//
+// tuple<int32, vector<float>> result(0, vector<float>(10, 0.0f));
+// while (get<0>(result) < 5) {
+// get<0>(result) = get<0>(result) + 1;
+// get<1>(result) = get<1>(result) + vector<float>(10, 1.0f);
+// }
+TEST_F(WhileTest, WhileWithTupleResult) {
+ std::vector<Shape> shape_elements = {ShapeUtil::MakeShape(S32, {}),
+ ShapeUtil::MakeShape(F32, {10})};
+ Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements);
+
+ // Create a computation for the condition.
+ // Repeat for 5 iterations.
+ Computation condition;
+ {
+ ComputationBuilder builder(client_, "condition");
+ auto prev = builder.Parameter(0, result_shape, "prev");
+ auto iteration = builder.GetTupleElement(prev, 0);
+ builder.Gt(builder.ConstantR0<int32>(5), iteration);
+ condition = builder.Build().ConsumeValueOrDie();
+ }
+
+ // Create a computation for the body.
+ // Add 1 to the iteration variable and add a constant vector of 1.0f to
+ // the weight variable, both of which are tuple elements.
+ Computation body;
+ {
+ ComputationBuilder builder(client_, "body");
+ auto prev = builder.Parameter(0, result_shape, "prev");
+ auto iteration = builder.GetTupleElement(prev, 0);
+ auto weights = builder.GetTupleElement(prev, 1);
+ auto input = builder.ConstantR1<float>(10, 1.f);
+ auto new_weights = builder.Add(weights, input);
+ auto result = builder.Tuple(
+ {builder.Add(iteration, builder.ConstantR0<int32>(1)), new_weights});
+ body = builder.Build().ConsumeValueOrDie();
+ }
+
+ // Create a While node with computations for the condition and the body.
+ ComputationBuilder builder(client_, "while");
+ auto init = builder.Tuple(
+ {builder.ConstantR0<int32>(0), builder.ConstantR1<float>(10, 0.f)});
+ auto result = builder.While(condition, body, init);
+ VLOG(2) << "while = " << ShapeUtil::HumanString(
+ *builder.GetShape(result).ConsumeValueOrDie());
+
+ auto expected_counter = LiteralUtil::CreateR0<int32>(5);
+ auto expected_data = LiteralUtil::CreateR1<float>(
+ {5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f});
+ auto expected =
+ LiteralUtil::MakeTuple({expected_counter.get(), expected_data.get()});
+ VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape());
+ ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001));
+}
+
+// Tests a while node when the result type T is a vector of S32.
+//
+// vector<int32> result = {0, 0, 0, 0, 0, 0};
+// while (result[0] < count) {
+// result += (1, U[0, 100], U[0, 100], U[0, 100], U[0, 100], U[0, 100]);
+// }
+//
+// This test misuses a vector to represent a pair:
+// (iteration, random vector).
+//
+// Note: this test currently only tests generating random values within a loop.
+// Per backend the values generated can be different as the different backends
+// use different random number generators.
+// TODO(b/32240857): Extend test to verify outputs.
+TEST_F(WhileTest, WhileWithPrngScalarResult) {
+ auto v6s32 = ShapeUtil::MakeShape(S32, {6});
+
+ // Create a computation for the condition: repeat for count iterations.
+ auto build_condition = [this, v6s32](int count) {
+ ComputationBuilder builder(client_, TestName());
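+ // Slice out element 0 of the loop state and reshape it to a scalar so it
+ // can be compared against the trip count.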
+ auto prev = builder.Reshape(
+ builder.Slice(builder.Parameter(0, v6s32, "prev"), {0}, {1}), {0}, {});
+ builder.Gt(builder.ConstantR0<int32>(count), prev);
+ return builder.Build().ConsumeValueOrDie();
+ };
+
+ // Create a computation for the body: increment the iteration count in
+ // element 0 and add random values to the remaining five elements.
+ Computation body;
+ {
+ ComputationBuilder builder(client_, "body");
+ auto prev = builder.Parameter(0, v6s32, "prev");
+ auto inc = builder.ConcatInDim(
+ {builder.ConstantR1<int32>({1}),
+ builder.RngUniform(builder.ConstantR0<int32>(0),
+ builder.ConstantR0<int32>(100),
+ ShapeUtil::MakeShape(S32, {5}))},
+ 0);
+ auto result = builder.Add(inc, prev);
+ body = builder.Build().ConsumeValueOrDie();
+ }
+
+ // Create a While node with computations for the condition and the body.
+ auto while_loop = [this, &body, build_condition](int count) {
+ ComputationBuilder builder(client_, TestName());
+ auto init = builder.ConstantR1<int32>({0, 0, 0, 0, 0, 0});
+ auto result = builder.While(build_condition(count), body, init);
+ auto shape = builder.GetShape(result).ConsumeValueOrDie();
+ return builder.Build();
+ };
+
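+ // Run the loop for a few trip counts. Per the TODO above, only successful
+ // execution is checked here; the random outputs themselves are not
+ // verified. The fixed seed keeps runs reproducible.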
+ for (int i = 1; i < 4; ++i) {
+ TF_ASSIGN_OR_ASSERT_OK(auto computation, while_loop(i));
+ TF_ASSIGN_OR_ASSERT_OK(auto result,
+ client_->ExecuteAndTransfer(computation, {}, nullptr,
+ nullptr, /*seed=*/65));
+ }
+}
+
+void BM_WhileLoop(int num_iters) {
+ // Benchmark a simple kernel to measure while loop overheads.
+ tensorflow::testing::StopTiming();
+
+ se::Platform* platform = PlatformUtil::GetDefaultPlatform().ValueOrDie();
+ auto executors = PlatformUtil::GetStreamExecutors(platform).ValueOrDie();
+ StreamExecutorMemoryAllocator allocator(platform, executors);
+ LocalClient* client =
+ ClientLibrary::GetOrCreateLocalClient(platform).ValueOrDie();
+
+ Shape loop_state_shape = ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(F32, {10})});
+
+ // Create while condition computation with 'loop_limit'.
+ const int32 loop_limit = 100;
+ Computation condition;
+ {
+ ComputationBuilder builder(client, "condition");
+ auto prev = builder.Parameter(0, loop_state_shape, "prev");
+ auto iteration = builder.GetTupleElement(prev, 0);
+ builder.Lt(iteration, builder.ConstantR0<int32>(loop_limit));
+ condition = builder.Build().ConsumeValueOrDie();
+ }
+
+ // Create while body computation with unit loop increment.
+ Computation body;
+ {
+ ComputationBuilder builder(client, "body");
+ auto prev = builder.Parameter(0, loop_state_shape, "prev");
+ auto iteration = builder.GetTupleElement(prev, 0);
+ auto weights = builder.GetTupleElement(prev, 1);
+ auto one = builder.ConstantR0<int32>(1);
+ auto next_iteration = builder.Add(iteration, one);
+ auto one_vec = builder.ConstantR1<float>(10, 1.f);
+ auto new_weights = builder.Add(weights, one_vec);
+ auto result = builder.Tuple({next_iteration, new_weights});
+ body = builder.Build().ConsumeValueOrDie();
+ }
+
+ // Create a While instruction.
+ ComputationBuilder builder(client, "while");
+ auto init = builder.Tuple(
+ {builder.ConstantR0<int32>(0), builder.ConstantR1<float>(10, 0.f)});
+ builder.While(condition, body, init);
+ auto computation = builder.Build().ConsumeValueOrDie();
+
+ // Run some warm-up executions.
+ LocalExecuteOptions options;
+ options.set_allocator(&allocator);
+ const int kWarmups = 2;
+ for (int i = 0; i < kWarmups; ++i) {
+ auto result = client->ExecuteLocally(computation, {}, options);
+ ASSERT_TRUE(result.ok());
+ }
+
+ // Run benchmark.
+ tensorflow::testing::StartTiming();
+ for (int i = 0; i < num_iters; ++i) {
+ auto result = client->ExecuteLocally(computation, {}, options);
+ ASSERT_TRUE(result.ok());
+ }
+}
+
+// TODO(b/32470510): Benchmark fails on parallel CPU backend.
+#ifndef XLA_TEST_BACKEND_CPU_PARALLEL
+BENCHMARK(BM_WhileLoop);
+#endif
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ tensorflow::testing::RunBenchmarks();
+ return RUN_ALL_TESTS();
+}