Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/tests')
13 files changed, 1445 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/tests/BUILD b/tensorflow/compiler/xla/service/gpu/tests/BUILD new file mode 100644 index 0000000000..686c3c16c9 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/tests/BUILD @@ -0,0 +1,223 @@ +# Description: GPU-specific XLA tests. For example, codegen tests that +# verify the IR emitted. +# +# TODO(jlebar): None of these tests actually use the GPU, so they should not +# need to run on machines with GPUs present. + +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = [":friends"]) + +package_group( + name = "friends", + includes = [ + "//tensorflow/compiler/xla:friends", + ], +) + +# Filegroup used to collect source files for dependency checking. +filegroup( + name = "c_srcs", + data = glob([ + "**/*.cc", + "**/*.h", + ]), +) + +load("//tensorflow:tensorflow.bzl", "tf_cc_test") + +cc_library( + name = "gpu_codegen_test", + testonly = True, + srcs = ["gpu_codegen_test.cc"], + hdrs = ["gpu_codegen_test.h"], + tags = [ + "requires-gpu-sm35", + ], + deps = [ + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", + "//tensorflow/compiler/xla/service:gpu_plugin", + "//tensorflow/compiler/xla/service/gpu:gpu_executable", + "//tensorflow/compiler/xla/tests:filecheck", + "//tensorflow/compiler/xla/tests:llvm_irgen_test_base", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "gpu_copy_test", + srcs = ["gpu_copy_test.cc"], + tags = [ + "requires-gpu-sm35", + ], + deps = [ + ":gpu_codegen_test", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +tf_cc_test( + name = "gpu_ftz_test", + srcs = ["gpu_ftz_test.cc"], + tags = [ + "requires-gpu-sm35", + ], + deps = [ + ":gpu_codegen_test", + "//tensorflow/core:test_main", + ], +) + +tf_cc_test( + name = "gpu_index_test", + srcs = ["gpu_index_test.cc"], + tags = [ + "requires-gpu-sm35", + ], + deps = [ + ":gpu_codegen_test", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla:xla_proto", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_module_config", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +tf_cc_test( + name = "gpu_infeed_test", + srcs = ["infeed_test.cc"], + tags = [ + "requires-gpu-sm35", + ], + deps = [ + ":gpu_codegen_test", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:global_data", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/core:lib", + "//tensorflow/core:test_main", + ], +) + +tf_cc_test( + name = "gpu_kernel_tiling_test", + srcs = ["gpu_kernel_tiling_test.cc"], + tags = [ + "requires-gpu-sm35", + ], + deps = [ + ":gpu_codegen_test", + "//tensorflow/compiler/xla/service:hlo", + 
"//tensorflow/compiler/xla/service:hlo_module_config", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +tf_cc_test( + name = "gpu_ldg_test", + srcs = ["gpu_ldg_test.cc"], + tags = ["requires-gpu-sm35"], + deps = [ + ":gpu_codegen_test", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +tf_cc_test( + name = "gpu_noalias_test", + srcs = ["gpu_noalias_test.cc"], + tags = [ + "requires-gpu-sm35", + ], + deps = [ + ":gpu_codegen_test", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +tf_cc_test( + name = "gpu_fusion_test", + srcs = ["gpu_fusion_test.cc"], + tags = [ + "requires-gpu-sm35", + ], + deps = [ + ":gpu_codegen_test", + "//tensorflow/compiler/xla/service:hlo_module_config", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +tf_cc_test( + name = "gpu_unrolling_test", + srcs = ["gpu_unrolling_test.cc"], + tags = [ + "requires-gpu-sm35", + ], + deps = [ + ":gpu_codegen_test", + "//tensorflow/compiler/xla/service:hlo_module_config", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +tf_cc_test( + name = "gpu_alignment_test", + testonly = True, + srcs = ["gpu_alignment_test.cc"], + tags = [ + "requires-gpu-sm35", + ], + deps = [ + ":gpu_codegen_test", + "//tensorflow/compiler/xla/service:gpu_plugin", + "//tensorflow/compiler/xla/service/cpu:custom_call_target_registry", + "//tensorflow/compiler/xla/service/llvm_ir:alias_analysis", + "//tensorflow/compiler/xla/tests:filecheck", + "//tensorflow/compiler/xla/tests:llvm_irgen_test_base", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_alignment_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_alignment_test.cc new file mode 100644 index 0000000000..672c68e59b --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_alignment_test.cc @@ -0,0 +1,54 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include <memory> +#include <utility> + +#include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h" +#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" +#include "tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h" +#include "tensorflow/compiler/xla/tests/filecheck.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace gpu { +namespace { + +class GpuAlignmentTest : public GpuCodegenTest {}; + +TEST_F(GpuAlignmentTest, Test) { + const char* hlo_string = R"( +HloModule GpuAlignmentTest + +ENTRY main { + zero = f32[] constant(0) + tok = token[] after-all() + a = f32[100] parameter(0) + b_tup = (f32[200], token[]) infeed(tok) + b = f32[200] get-tuple-element(b_tup), index=0 + a_padded = f32[150] pad(a, zero), padding=0_50 + b_sliced = f32[150] slice(b), slice={[0:150]} + ROOT c = f32[150] add(a_padded, b_sliced) +} +)"; + + CompileAndVerifyIr(hlo_string, R"( +CHECK: @fusion(i8* align 64 dereferenceable(600) %alloc0, i8* align 16 dereferenceable(400) %alloc1, i8* align 64 dereferenceable(864) %temp_buf) +)"); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.cc new file mode 100644 index 0000000000..4b8415fe91 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.cc @@ -0,0 +1,50 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" +#include "tensorflow/compiler/xla/tests/filecheck.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { +namespace gpu { + +std::unique_ptr<HloModule> GpuCodegenTest::CreateNewModuleWithFTZ(bool ftz) { + HloModuleConfig config; + auto debug_options = legacy_flags::GetDebugOptionsFromFlags(); + debug_options.set_xla_gpu_ftz(ftz); + debug_options.set_xla_gpu_max_kernel_unroll_factor(1); + // TODO(b/38354253): Change tests to use Parameters instead of Constants. 
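+  // Without this, the constant_folding pass would evaluate the constant-built
+  // test graphs at compile time, leaving no kernel IR for the FileCheck
+  // patterns to match against.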
+ debug_options.add_xla_disable_hlo_passes("constant_folding"); + config.set_debug_options(debug_options); + + return MakeUnique<HloModule>(TestName(), config); +} + +void GpuCodegenTest::CompileAndVerifyPtx(std::unique_ptr<HloModule> hlo_module, + const string& pattern) { + std::unique_ptr<Executable> executable = + std::move(CompileToExecutable(std::move(hlo_module)).ValueOrDie()); + string ptx_str = + std::string(static_cast<GpuExecutable*>(executable.get())->ptx()); + StatusOr<bool> filecheck_result = RunFileCheck(ptx_str, pattern); + ASSERT_TRUE(filecheck_result.ok()); + EXPECT_TRUE(filecheck_result.ValueOrDie()); +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h b/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h new file mode 100644 index 0000000000..e4a3573bab --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h @@ -0,0 +1,42 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_TESTS_GPU_CODEGEN_TEST_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_TESTS_GPU_CODEGEN_TEST_H_ + +#include <string> + +#include "tensorflow/compiler/xla/tests/llvm_irgen_test_base.h" + +namespace xla { +namespace gpu { + +// Tests that verify IR or PTX emitted by the GPU backend is as expected. +class GpuCodegenTest : public LlvmIrGenTestBase { + protected: + // Like HloTestBase::CreateNewModule(), with a flag for configuring the ftz + // option. + std::unique_ptr<HloModule> CreateNewModuleWithFTZ(bool ftz); + + // Compiles the given HLO module to PTX and verifies the PTX matches the given + // FileCheck pattern. (See http://llvm.org/docs/CommandGuide/FileCheck.html). + void CompileAndVerifyPtx(std::unique_ptr<HloModule> hlo_module, + const string& pattern); +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_TESTS_GPU_CODEGEN_TEST_H_ diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc new file mode 100644 index 0000000000..ce69e058e6 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc @@ -0,0 +1,59 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include <memory> +#include <utility> + +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace gpu { + +class GpuCopyTest : public GpuCodegenTest {}; + +// The GPU backend should not emit a copy kernel for the kCopy instruction in +// this test. Instead, it should generate a CopyThunk which invokes cuMemcpy at +// runtime. +TEST_F(GpuCopyTest, UseMemcpy) { + HloComputation::Builder builder(TestName()); + + std::unique_ptr<Literal> literal = + LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}); + HloInstruction* constant = builder.AddInstruction( + HloInstruction::CreateConstant(std::move(literal))); + builder.AddInstruction(HloInstruction::CreateUnary( + constant->shape(), HloOpcode::kCopy, constant)); + + std::unique_ptr<HloComputation> computation = builder.Build(); + + auto hlo_module = CreateNewModule(); + hlo_module->AddEntryComputation(std::move(computation)); + + // There should not be any kernel prefixed "copy". + CompileAndVerifyIr(std::move(hlo_module), "; CHECK-NOT: define void @_copy", + /*match_optimized_ir=*/false); +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_ftz_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_ftz_test.cc new file mode 100644 index 0000000000..177b94934c --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_ftz_test.cc @@ -0,0 +1,119 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" + +// Check that the ftz (flush denormals to zero) flag is reflected in PTX as +// expected. + +namespace xla { +namespace gpu { +namespace { + +class GpuFtzTest : public GpuCodegenTest { + public: + explicit GpuFtzTest(bool ftz) : ftz_(ftz) {} + + // Creates an HLO module that performs the given binary operation on some + // data. 
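+  //
+  // For reference: flush-to-zero replaces f32 denormal operands and results
+  // with +/-0.0, so e.g. 1e-20f * 1e-20f (about 1e-40, a denormal) becomes
+  // 0.0f under ftz; the .ftz instruction variants implement exactly this.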
+ std::unique_ptr<HloModule> CreateBinaryOpModule(HloOpcode op) { + HloComputation::Builder builder(TestName()); + + Shape param_shape = ShapeUtil::MakeShapeWithLayout( + F32, /*dimensions=*/{100, 100}, /*minor_to_major=*/{1, 0}); + HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter( + /* parameter_number=*/0, param_shape, "x")); + HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter( + /* parameter_number=*/1, param_shape, "y")); + builder.AddInstruction(HloInstruction::CreateBinary(param_shape, op, x, y)); + + auto hlo_module = CreateNewModuleWithFTZ(ftz_); + hlo_module->AddEntryComputation(builder.Build()); + return hlo_module; + } + + // Creates an HLO module that performs the given unary operation on some data. + std::unique_ptr<HloModule> CreateUnaryOpModule(HloOpcode op) { + HloComputation::Builder builder(TestName()); + + Shape param_shape = ShapeUtil::MakeShapeWithLayout( + F32, /*dimensions=*/{100, 100}, /*minor_to_major=*/{1, 0}); + HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter( + /* parameter_number=*/0, param_shape, "x")); + builder.AddInstruction(HloInstruction::CreateUnary(param_shape, op, x)); + + auto hlo_module = CreateNewModuleWithFTZ(ftz_); + hlo_module->AddEntryComputation(builder.Build()); + return hlo_module; + } + + bool ftz_; +}; + +class GpuFtzEnabledTest : public GpuFtzTest { + public: + GpuFtzEnabledTest() : GpuFtzTest(/*ftz=*/true) {} +}; + +class GpuFtzDisabledTest : public GpuFtzTest { + public: + GpuFtzDisabledTest() : GpuFtzTest(/*ftz=*/false) {} +}; + +// Check that we emit mul.ftz.f32 when in ftz mode, and plain mul.f32 otherwise. +TEST_F(GpuFtzEnabledTest, MultiplyFtz) { + CompileAndVerifyPtx(CreateBinaryOpModule(HloOpcode::kMultiply), R"( + CHECK-NOT: mul.f32 + CHECK: mul.ftz.f32 + CHECK-NOT: mul.f32 + )"); +} +TEST_F(GpuFtzDisabledTest, MultiplyFtz) { + CompileAndVerifyPtx(CreateBinaryOpModule(HloOpcode::kMultiply), R"( + CHECK-NOT: mul.ftz.f32 + CHECK: mul.f32 + CHECK-NOT: mul.ftz.f32 + )"); +} + +// In NVPTX, exp(float) is implemented in libdevice, and consults __nvvm_reflect +// to determine whether or not ftz is enabled. The implementation uses two +// calls to ex2.approx. When ftz is on, we get two calls to the ftz version; +// when ftz is off, we get one call to the ftz version and one call to the +// regular version. +TEST_F(GpuFtzEnabledTest, ExpFtz) { + CompileAndVerifyPtx(CreateUnaryOpModule(HloOpcode::kExp), R"( + CHECK-NOT: ex2.approx.f32 + CHECK: ex2.approx.ftz.f32 + CHECK-NOT: ex2.approx.f32 + CHECK: ex2.approx.ftz.f32 + CHECK-NOT: ex2.approx.f32 + CHECK-NOT: ex2.approx.ftz.f32 + )"); +} + +TEST_F(GpuFtzDisabledTest, ExpFtz) { + CompileAndVerifyPtx(CreateUnaryOpModule(HloOpcode::kExp), R"( + CHECK-NOT: ex2.approx.f32 + CHECK-DAG: ex2.approx.ftz.f32 + CHECK-DAG: ex2.approx.f32 + CHECK-NOT: ex2.approx.f32 + CHECK-NOT: ex2.approx.ftz.f32 + )"); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_fusion_test.cc new file mode 100644 index 0000000000..674b436a8e --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_fusion_test.cc @@ -0,0 +1,59 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include <utility> + +#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace gpu { +namespace { + +class GpuFusionTest : public GpuCodegenTest {}; + +TEST_F(GpuFusionTest, FusedReshape) { + const char* hlo_text = R"( + HloModule test_module + + fused_computation { + p0.param_0 = f32[4,1,1]{2,1,0} parameter(0) + p1.param_1 = f32[4,1]{1,0} parameter(1) + reshape = f32[4,1]{1,0} reshape(p0.param_0) + ROOT add = f32[4,1] add(reshape, p1.param_1) + } + + ENTRY BroadcastIntoAdd { + p0 = f32[4,1,1]{2,1,0} parameter(0) + p1 = f32[4,1]{1,0} parameter(1) + ROOT fusion = f32[4,1]{1,0} fusion(p0, p1), kind=kLoop, + calls=fused_computation + } +)"; + + CompileAndVerifyIr(hlo_text, + R"( +; CHECK-LABEL: @fusion +; CHECK: fadd +; CHECK: } + )"); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_index_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_index_test.cc new file mode 100644 index 0000000000..e5958165ef --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_index_test.cc @@ -0,0 +1,147 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include <memory> +#include <utility> + +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/xla.pb.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace gpu { + +// This file tests the index expressions used to reference source tensors. When +// the destination tensor and source tensor have compatible shapes, the linear +// index is used to access the source tensor. 
Otherwise, dimensional indices +// computed from the linear index are used to access the source tensor. + +class GpuIndexTest : public GpuCodegenTest {}; + +TEST_F(GpuIndexTest, CompatibleUseLinearIndex) { + HloComputation::Builder builder(TestName()); + + auto param_shape = ShapeUtil::MakeShape(F32, {5, 7, 2}); + HloInstruction* param_x = builder.AddInstruction( + HloInstruction::CreateParameter(0, param_shape, "x")); + HloInstruction* param_y = builder.AddInstruction( + HloInstruction::CreateParameter(1, param_shape, "y")); + builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(PRED, {5, 7, 2}), HloOpcode::kGe, param_x, param_y)); + + auto hlo_module = CreateNewModule(); + hlo_module->AddEntryComputation(builder.Build()); + + // Check the optimized IR as the unoptimized IR contains dead udiv and urem. + CompileAndVerifyIr(std::move(hlo_module), + R"( +; CHECK-NOT: udiv +; CHECK-NOT: urem + )", + /*match_optimized_ir=*/true); +} + +TEST_F(GpuIndexTest, CompatibleUseLinearIndexWithReshape) { + HloModuleConfig config; + config.set_debug_options(HloTestBase::GetDebugOptionsForTest()); + auto module = ParseHloString(R"( + HloModule test_module + + ENTRY CompatibleUseLinearIndexWithReshape { + x = f32[5,7,2]{2,1,0} parameter(0) + y = f32[5,14]{1,0} parameter(1) + reshape = f32[5,7,2]{2,1,0} reshape(y) + ROOT gte = pred[5,7,2]{2,1,0} greater-than-or-equal-to(x, reshape) + })", + config) + .ValueOrDie(); + + // Check the optimized IR as the unoptimized IR contains dead udiv and urem. + CompileAndVerifyIr(std::move(module), + R"( +; CHECK-NOT: udiv +; CHECK-NOT: urem + )", + /*match_optimized_ir=*/true); +} + +TEST_F(GpuIndexTest, CompatibleUseLinearIndexWithReshapeAndBroadcast) { + HloModuleConfig config; + config.set_debug_options(HloTestBase::GetDebugOptionsForTest()); + auto module = ParseHloString(R"( + HloModule test_module + + ENTRY CompatibleUseLinearIndexWithReshape { + x = f32[5,7,2]{2,1,0} parameter(0) + y = f32[14]{0} parameter(1) + reshape = f32[7,2]{1,0} reshape(y) + broadcast = f32[5,7,2]{2,1,0} broadcast(reshape), dimensions={1,2} + ROOT gte = pred[5,7,2]{2,1,0} greater-than-or-equal-to(x, broadcast) + })", + config) + .ValueOrDie(); + + // Check the optimized IR reuses the linear index by calculating modulo 14. + CompileAndVerifyIr(std::move(module), + R"( +; CHECK: %[[urem1:.*]] = urem i{{[0-9]*}} %[[linear_index:.*]], 14 +; CHECK: %[[bitcast:.*]] = bitcast i8 addrspace(1)* %[[alloc:.*]] to float addrspace(1)* +; CHECK: %[[idx1:.*]] = zext i{{[0-9]*}} %[[urem1]] to i64 +; CHECK: getelementptr inbounds float, float addrspace(1)* %[[bitcast]], i64 %[[idx1]] + )", + /*match_optimized_ir=*/true); +} + +TEST_F(GpuIndexTest, CompatibleUseLinearIndexWithSizeOneDimensions) { + HloModuleConfig config; + auto debug_options = HloTestBase::GetDebugOptionsForTest(); + debug_options.set_xla_gpu_max_kernel_unroll_factor(1); + config.set_debug_options(debug_options); + + auto module = ParseHloString(R"( + HloModule test_module + + ENTRY CompatibleUseLinearIndexWithSizeOneDimensions { + x = f32[1,1024,1,256]{3,2,1,0} parameter(0) + ROOT y = f16[1,1024,1,256]{2,3,1,0} convert(x) + })", + config) + .ValueOrDie(); + + // Check that the unoptimized IR reuses the linear index. 
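+  // Reuse is possible here because the operand and result layouts
+  // ({3,2,1,0} vs. {2,3,1,0}) differ only in the position of a size-1
+  // dimension, so the physical element order of the two buffers is the same.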
+  CompileAndVerifyIr(std::move(module),
+                     R"(
+; CHECK-LABEL: @fusion
+; CHECK: udiv i32 %[[linear_index:.*]], 262144
+; CHECK: %[[ld_addr:.*]] = getelementptr inbounds float, float* {{.*}}, i32 %[[linear_index]]
+; CHECK: load float, float* %[[ld_addr]]
+; CHECK: %[[st_addr:.*]] = getelementptr inbounds half, half* {{.*}}, i32 %[[linear_index]]
+; CHECK: store half {{.*}}, half* %[[st_addr]]
+  )",
+                     /*match_optimized_ir=*/false);
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc
new file mode 100644
index 0000000000..cca35316f0
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc
@@ -0,0 +1,177 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <utility>
+
+#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h"
+#include "tensorflow/compiler/xla/service/hlo_module_config.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+class GpuKernelTilingTest : public GpuCodegenTest {
+ protected:
+  GpuKernelTilingTest() {
+    auto debug_options = HloTestBase::GetDebugOptionsForTest();
+    // Disable layout_assignment to use the preassigned layouts.
+    debug_options.add_xla_disable_hlo_passes("layout_assignment");
+    config_.set_debug_options(debug_options);
+  }
+  HloModuleConfig config_;
+};
+
+TEST_F(GpuKernelTilingTest, UnnestedTransposeWithProperDimensionsTiled) {
+  const char *const kHloString = R"(
+  HloModule unnested_transpose_1
+
+  ENTRY unnested_transpose_1 {
+    para0 = f16[32,3,64]{2,1,0} parameter(0)
+    ROOT copy1 = f16[32,3,64]{1,0,2} copy(para0)
+  })";
+
+  // Check that a call to llvm.nvvm.barrier0 is generated.
+  auto hlo_module = ParseHloString(kHloString, config_).ValueOrDie();
+  CompileAndVerifyIr(std::move(hlo_module),
+                     R"(
+; CHECK-LABEL: define void @copy
+; CHECK: tail call void @llvm.nvvm.barrier0()
+; CHECK: }
+)",
+                     /*match_optimized_ir=*/true);
+
+  // Check that the kernel runs correctly.
+  EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{0.0}));
+}
+
+TEST_F(GpuKernelTilingTest, UnnestedTransposeWithSmallDimensionsNotTiled) {
+  const char *const kHloString = R"(
+  HloModule unnested_transpose_2
+
+  ENTRY unnested_transpose_2 {
+    para0 = f16[2,3,64]{2,1,0} parameter(0)
+    ROOT copy1 = f16[2,3,64]{1,0,2} copy(para0)
+  })";
+
+  // Check that a call to llvm.nvvm.barrier0 is not generated.
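+  // The tiled transpose path stages elements through a shared-memory tile
+  // and synchronizes with barrier0; when the transposed dimensions are too
+  // small to make tiling worthwhile, the plain loop emitter is used instead,
+  // so no barrier should appear.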
+ auto hlo_module = ParseHloString(kHloString, config_).ValueOrDie(); + CompileAndVerifyIr(std::move(hlo_module), + R"( +; CHECK-LABEL: define void @copy +; CHECK-NOT: tail call void @llvm.nvvm.barrier0() +; CHECK: } +)", + /*match_optimized_ir=*/true); +} + +TEST_F(GpuKernelTilingTest, SimpleFusionWithTransposeTiled) { + const char *const kHloString = R"( + HloModule multiple_output_fusion_1 + fused_computation.1 { + param0 = f32[4,5,6,7,8]{4,3,2,1,0} parameter(0) + copy = f32[4,5,6,7,8]{2,1,4,3,0} copy(param0) + ROOT convert = f16[4,5,6,7,8]{2,1,4,3,0} convert(copy) + } + + ENTRY copy_in_fusion_run_without_hlo_passes { + para0 = f32[4,5,6,7,8]{4,3,2,1,0} parameter(0) + ROOT fusion.1 = f16[4,5,6,7,8]{2,1,4,3,0} fusion(para0), kind=kLoop, + calls=fused_computation.1 + })"; + + // Check that a call to llvm.nvvm.barrier0 is generated. + auto hlo_module = ParseHloString(kHloString, config_).ValueOrDie(); + CompileAndVerifyIr(std::move(hlo_module), + R"( +; CHECK-LABEL: define void @fusion +; CHECK: tail call void @llvm.nvvm.barrier0() +; CHECK: } +)", + /*match_optimized_ir=*/true); + + // Check that the kernel runs correctly. + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{0.0})); +} + +TEST_F(GpuKernelTilingTest, MultipleOutputFusionWithOnePossibleTransposeTiled) { + const char *const kHloString = R"( + HloModule multiple_output_fusion_1 + fused_computation.1 { + param0 = f16[8,31,31,65]{3,2,1,0} parameter(0) + param1 = f16[8,31,31,65]{3,2,1,0} parameter(1) + copy0 = f16[8,31,31,65]{2,1,3,0} copy(param0) + copy1 = f16[8,31,31,65]{2,1,3,0} copy(param1) + ROOT tuple1 = (f16[8,31,31,65]{2,1,3,0}, f16[8,31,31,65]{2,1,3,0}) + tuple(copy0, copy1) + } + + ENTRY multiple_output_fusion_1 { + para0 = f16[8,31,31,65]{3,2,1,0} parameter(0) + para1 = f16[8,31,31,65]{3,2,1,0} parameter(1) + ROOT fusion.1 = (f16[8,31,31,65]{2,1,3,0}, f16[8,31,31,65]{2,1,3,0}) + fusion(para0,para1), kind=kLoop, calls=fused_computation.1 + })"; + + // Check that a call to llvm.nvvm.barrier0 is generated. + auto hlo_module = ParseHloString(kHloString, config_).ValueOrDie(); + CompileAndVerifyIr(std::move(hlo_module), + R"( +; CHECK-LABEL: define void @fusion +; CHECK: tail call void @llvm.nvvm.barrier0() +; CHECK: } +)", + /*match_optimized_ir=*/true); + + // Check that the kernel runs correctly. + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{0.0})); +} + +TEST_F(GpuKernelTilingTest, + MultipleOutputFusionWithTwoPossibleTransposesNotTiled) { + const char *const kHloString = R"( + HloModule multiple_output_fusion_2 + fused_computation.1 { + param0 = f16[8,31,31,65]{3,2,1,0} parameter(0) + param1 = f16[8,31,31,65]{1,3,2,0} parameter(1) + copy2 = f16[8,31,31,65]{2,1,3,0} copy(param0) + copy3 = f16[8,31,31,65]{2,1,3,0} copy(param1) + ROOT tuple1 = (f16[8,31,31,65]{2,1,3,0}, f16[8,31,31,65]{2,1,3,0}) + tuple(copy2, copy3) + } + + ENTRY multiple_output_fusion_2 { + para0 = f16[8,31,31,65]{3,2,1,0} parameter(0) + para1 = f16[8,31,31,65]{1,3,2,0} parameter(1) + ROOT fusion1 = (f16[8,31,31,65]{2,1,3,0}, f16[8,31,31,65]{2,1,3,0}) + fusion(para0,para1), kind=kLoop, calls=fused_computation.1 + })"; + + // Check that a call to llvm.nvvm.barrier0 is not generated. 
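+  // The two copies read operands with different input layouts, so they imply
+  // two incompatible tiling choices for the fusion; the emitter is expected
+  // to fall back to the untiled path rather than favor one operand.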
+ auto hlo_module = ParseHloString(kHloString, config_).ValueOrDie(); + CompileAndVerifyIr(std::move(hlo_module), + R"( +; CHECK-LABEL: define void @fusion +; CHECK-NOT: tail call void @llvm.nvvm.barrier0() +; CHECK: } +)", + /*match_optimized_ir=*/true); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_ldg_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_ldg_test.cc new file mode 100644 index 0000000000..6c9ae7bada --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_ldg_test.cc @@ -0,0 +1,141 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Tests that we emit ld.global.nc (the PTX instruction corresponding to CUDA's +// __ldg builtin) for reads of buffers that don't change during a kernel's +// execution. + +#include <memory> +#include <utility> + +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace gpu { + +class GpuLdgTest : public GpuCodegenTest {}; + +// Parameters are never overwritten, so parameter reads should get ld.global.nc +// reads. +TEST_F(GpuLdgTest, LdgForParamRead) { + HloComputation::Builder builder(TestName()); + + auto shape = ShapeUtil::MakeShape(F32, {2, 2}); + HloInstruction* param = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "x")); + builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param, param)); + std::unique_ptr<HloComputation> computation = builder.Build(); + + auto hlo_module = CreateNewModule(); + hlo_module->AddEntryComputation(std::move(computation)); + + CompileAndVerifyPtx(std::move(hlo_module), R"( + CHECK-NOT: ld.global.f32 + CHECK: ld.global.nc.f32 + )"); +} + +// Check that reading a buffer produced by a non-parameter HLO also results in +// ld.global.nc, if that buffer isn't modified within the instruction that reads +// it. 
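+//
+// For reference, ld.global.nc is what a CUDA kernel reading through the
+// __ldg() builtin compiles to:
+//
+//   __global__ void f(const float* __restrict__ in, float* out) {
+//     out[threadIdx.x] = __ldg(&in[threadIdx.x]);  // -> ld.global.nc.f32
+//   }
+//
+// Such loads go through the read-only (texture) cache, which is why they
+// are only safe for buffers that are not written during the kernel's
+// execution.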
+TEST_F(GpuLdgTest, LdgForNonParamRead) { + HloComputation::Builder builder(TestName()); + + auto shape = ShapeUtil::MakeShape(F32, {2, 2}); + HloInstruction* param = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "x")); + HloInstruction* add = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param, param)); + HloInstruction* square = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, add, add)); + builder.AddInstruction(HloInstruction::CreateTuple({add, square})); + std::unique_ptr<HloComputation> computation = builder.Build(); + + auto hlo_module = CreateNewModule(); + hlo_module->AddEntryComputation(std::move(computation)); + + CompileAndVerifyPtx(std::move(hlo_module), R"( + CHECK: { + CHECK-NOT: ld.global.f32 + CHECK: ld.global.nc.f32 + CHECK: } + )"); +} + +// Check that reading a buffer that's modified in-place does not produce +// ld.global.nc. +// +// We do this by creating a reduce that feeds into a sin. We don't currently +// fuse sin into reduce, and the sin is elementwise, so it reuses its input +// buffer as its output. +// +// It seems like a fair bet that we won't start fusing sin into the output of +// reduce in the foreseeable future. But if that turns out to be wrong, I give +// you, future reader, permission to delete this test. +TEST_F(GpuLdgTest, NoLdgWhenSharingBuffer) { + auto hlo_module = CreateNewModule(); + HloComputation::Builder builder(TestName()); + + HloComputation* reduce_computation; + { + auto embedded_builder = HloComputation::Builder("add"); + auto lhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {}), "lhs")); + auto rhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(F32, {}), "rhs")); + embedded_builder.AddInstruction( + HloInstruction::CreateBinary(lhs->shape(), HloOpcode::kAdd, lhs, rhs)); + reduce_computation = + hlo_module->AddEmbeddedComputation(embedded_builder.Build()); + } + + auto param_shape = ShapeUtil::MakeShape(F32, {2, 2}); + auto reduce_shape = ShapeUtil::MakeShape(F32, {2}); + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, param_shape, "x")); + HloInstruction* reduce = builder.AddInstruction(HloInstruction::CreateReduce( + reduce_shape, + builder.AddInstruction(HloInstruction::CreateBinary( + param_shape, HloOpcode::kAdd, param, param)), + builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0))), + {0}, reduce_computation)); + builder.AddInstruction( + HloInstruction::CreateUnary(reduce_shape, HloOpcode::kSin, reduce)); + + std::unique_ptr<HloComputation> computation = builder.Build(); + hlo_module->AddEntryComputation(std::move(computation)); + + CompileAndVerifyPtx(std::move(hlo_module), R"( + CHECK-LABEL: .entry sin + CHECK: { + CHECK-NOT: ld.global.nc.f32 + CHECK: ld.global.f32 + CHECK: } + )"); +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_noalias_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_noalias_test.cc new file mode 100644 index 0000000000..c42e5704a4 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_noalias_test.cc @@ -0,0 +1,68 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include <memory> +#include <utility> + +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace gpu { + +class GpuNoAliasTest : public GpuCodegenTest {}; + +TEST_F(GpuNoAliasTest, Concat) { + HloComputation::Builder builder(TestName()); + + auto param_shape = ShapeUtil::MakeShape(F32, {2, 2}); + HloInstruction* param_x = builder.AddInstruction( + HloInstruction::CreateParameter(0, param_shape, "x")); + HloInstruction* param_y = builder.AddInstruction( + HloInstruction::CreateParameter(1, param_shape, "y")); + HloInstruction* concat = + builder.AddInstruction(HloInstruction::CreateConcatenate( + ShapeUtil::MakeShape(F32, {2, 4}), {param_x, param_y}, 1)); + builder.AddInstruction(HloInstruction::CreateConcatenate( + ShapeUtil::MakeShape(F32, {2, 6}), {concat, param_x}, 1)); + + std::unique_ptr<HloComputation> computation = builder.Build(); + + auto hlo_module = CreateNewModule(); + hlo_module->AddEntryComputation(std::move(computation)); + + CompileAndVerifyIr(std::move(hlo_module), + R"( +; CHECK: %[[x_gep:.*]] = getelementptr inbounds [2 x [2 x float]], [2 x [2 x float]]* %x{{.*}}, i32 0 +; CHECK: load float, float* %[[x_gep]], {{.*}}, !noalias ![[param_noalias:.*]] +; CHECK: %[[y_gep:.*]] = getelementptr inbounds [2 x [2 x float]], [2 x [2 x float]]* %y{{.*}}, i32 0 +; CHECK: load float, float* %[[y_gep]], {{.*}}, !noalias ![[param_noalias]] +; CHECK: %[[result_ptr:.*]] = bitcast [2 x [6 x float]]* %fusion{{.*}} to float* +; CHECK: %[[result_gep:.*]] = getelementptr inbounds float, float* %[[result_ptr]] +; CHECK: store float {{.*}}, float* %[[result_gep]], !alias.scope ![[param_noalias]] +; CHECK: ![[param_noalias]] = !{![[retval_buffer:.*]]} + )", + /*match_optimized_ir=*/false); +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_unrolling_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_unrolling_test.cc new file mode 100644 index 0000000000..9622936306 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_unrolling_test.cc @@ -0,0 +1,185 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include <utility> + +#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace gpu { +namespace { + +class GpuUnrollingTest : public GpuCodegenTest {}; + +const char *const kAddModule = R"( + HloModule test_module + + fused_computation { + p0.param_0 = f32[2,2]{1,0} parameter(0) + p1.param_1 = f32[2,2]{1,0} parameter(1) + ROOT add = f32[2,2] add(p0.param_0, p1.param_1) + } + + ENTRY BroadcastIntoAdd { + p0 = f32[2,2]{1,0} parameter(0) + p1 = f32[2,2]{1,0} parameter(1) + ROOT fusion = f32[2,2]{1,0} fusion(p0, p1), kind=kLoop, + calls=fused_computation + })"; + +TEST_F(GpuUnrollingTest, DoNotUnroll) { + HloModuleConfig config; + auto debug_options = HloTestBase::GetDebugOptionsForTest(); + debug_options.set_xla_gpu_max_kernel_unroll_factor(1); + config.set_debug_options(debug_options); + auto hlo_module = ParseHloString(kAddModule, config).ValueOrDie(); + + CompileAndVerifyIr(std::move(hlo_module), + R"( +; CHECK-LABEL: @fusion +; CHECK: fadd +; CHECK-NOT: fadd +; CHECK: } + )", + /*match_optimized_ir=*/true); +} + +TEST_F(GpuUnrollingTest, UnrollFourTimes) { + HloModuleConfig config; + auto debug_options = HloTestBase::GetDebugOptionsForTest(); + // We request a factor of 8, but the computation works on 4 elements, limiting + // the maximum unroll factor. + debug_options.set_xla_gpu_max_kernel_unroll_factor(8); + config.set_debug_options(debug_options); + auto hlo_module = ParseHloString(kAddModule, config).ValueOrDie(); + + CompileAndVerifyIr(std::move(hlo_module), + R"( +; CHECK-LABEL: @fusion +; CHECK: fadd +; CHECK: fadd +; CHECK: fadd +; CHECK: fadd +; CHECK-NOT: fadd +; CHECK: } + )", + /*match_optimized_ir=*/true); +} + +TEST_F(GpuUnrollingTest, UnrollDefaultTimes) { + // The default unrolling factor is 4. 
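+  // With factor 4 each thread then handles four consecutive f32 elements, so
+  // the element accesses are expected to vectorize into the <4 x float>
+  // loads and stores matched below.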
+ HloModuleConfig config; + config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); + auto hlo_module = ParseHloString(kAddModule, config).ValueOrDie(); + + CompileAndVerifyIr(std::move(hlo_module), + R"( +; CHECK-LABEL: @fusion +; CHECK: load <4 x float> +; CHECK: fadd +; CHECK: fadd +; CHECK: fadd +; CHECK: fadd +; CHECK-NOT: fadd +; CHECK: store <4 x float> +; CHECK: } + )", + /*match_optimized_ir=*/true); +} + +TEST_F(GpuUnrollingTest, UnrollUnfusedAdd) { + HloModuleConfig config; + auto debug_options = HloTestBase::GetDebugOptionsForTest(); + debug_options.set_xla_gpu_max_kernel_unroll_factor(4); + config.set_debug_options(debug_options); + + const char *const kUnfusedAddModule = R"( + HloModule test_module + + ENTRY AddFunc { + p0 = f32[2,2]{1,0} parameter(0) + p1 = f32[2,2]{1,0} parameter(1) + ROOT add = f32[2,2]{1,0} add(p0, p1) + })"; + auto hlo_module = ParseHloString(kUnfusedAddModule, config).ValueOrDie(); + + CompileAndVerifyIr(std::move(hlo_module), + R"( +; CHECK-LABEL: @add +; CHECK: load <4 x float> +; CHECK: fadd +; CHECK: fadd +; CHECK: fadd +; CHECK: fadd +; CHECK-NOT: fadd +; CHECK: store <4 x float> +; CHECK: } + )", + /*match_optimized_ir=*/true); +} + +TEST_F(GpuUnrollingTest, UnrollMultiOutputFusion) { + HloModuleConfig config; + auto debug_options = HloTestBase::GetDebugOptionsForTest(); + debug_options.set_xla_gpu_max_kernel_unroll_factor(2); + config.set_debug_options(debug_options); + + const char *const kMultiOutputFusionModule = R"( + HloModule test_module + + fused_computation { + p0.param_0 = f32[2,2]{1,0} parameter(0) + p1.param_1 = f32[2,2]{1,0} parameter(1) + add = f32[2,2]{1,0} add(p0.param_0, p1.param_1) + mul = f32[2,2]{1,0} multiply(p0.param_0, p1.param_1) + ROOT tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(add, mul) + } + + ENTRY BroadcastIntoAdd { + p0 = f32[2,2]{1,0} parameter(0) + p1 = f32[2,2]{1,0} parameter(1) + ROOT fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(p0, p1), kind=kLoop, + calls=fused_computation + })"; + auto hlo_module = + ParseHloString(kMultiOutputFusionModule, config).ValueOrDie(); + + CompileAndVerifyIr(std::move(hlo_module), + R"( +; CHECK-LABEL: @fusion +; CHECK: load <2 x float> +; CHECK: load <2 x float> +; CHECK-NOT: load <2 x float> +; CHECK: fadd +; CHECK: fmul +; CHECK: fadd +; CHECK: fmul +; CHECK: store <2 x float> +; CHECK: store <2 x float> +; CHECK-NOT: store <2 x float> +; CHECK-NOT: fadd +; CHECK-NOT: fmul +; CHECK: } + )", + /*match_optimized_ir=*/true); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc b/tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc new file mode 100644 index 0000000000..ba5cd2d84d --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc @@ -0,0 +1,121 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#include <unistd.h>
+#include <memory>
+
+#include "tensorflow/compiler/xla/client/global_data.h"
+#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/literal.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/math/math_util.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+class InfeedTest : public ClientLibraryTestBase {
+ protected:
+  // Transfers the given literal to the infeed interface of the device and
+  // checks that the data returned by the Infeed HLO matches the literal.
+  void TestInfeedRoundTrip(const Literal& literal) {
+    // TODO(b/30481585) Explicitly reset the Infeed state so that the
+    // test is not affected by the state from the previous tests.
+    ASSERT_IS_OK(client_->TransferToInfeed(literal));
+    XlaBuilder builder(TestName());
+    Infeed(&builder, literal.shape());
+    if (ShapeUtil::IsTuple(literal.shape())) {
+      // TODO(b/30609564): Use ComputeAndCompareLiteral instead.
+      ComputeAndCompareTuple(&builder, literal, {});
+    } else {
+      ComputeAndCompareLiteral(&builder, literal, {});
+    }
+  }
+};
+
+TEST_F(InfeedTest, SingleInfeedR0Bool) {
+  TestInfeedRoundTrip(*LiteralUtil::CreateR0<bool>(true));
+}
+
+TEST_F(InfeedTest, SingleInfeedR1U32) {
+  TestInfeedRoundTrip(*LiteralUtil::CreateR1<uint32>({1, 2, 3}));
+}
+
+TEST_F(InfeedTest, SingleInfeedR2F32) {
+  TestInfeedRoundTrip(*LiteralUtil::CreateR2F32Linspace(0.0, 1.0, 128, 64));
+}
+
+TEST_F(InfeedTest, SingleInfeedR3F32) {
+  TestInfeedRoundTrip(
+      *LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
+                              {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}));
+}
+
+TEST_F(InfeedTest, SingleInfeedR3F32DifferentLayout) {
+  const Layout r3_dim0minor = LayoutUtil::MakeLayout({0, 1, 2});
+  const Layout r3_dim0major = LayoutUtil::MakeLayout({2, 1, 0});
+
+  TestInfeedRoundTrip(*LiteralUtil::CreateR3WithLayout(
+      {{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
+       {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}},
+      r3_dim0minor));
+
+  TestInfeedRoundTrip(*LiteralUtil::CreateR3WithLayout(
+      {{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
+       {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}},
+      r3_dim0major));
+}
+
+TEST_F(InfeedTest, SingleInfeedR4S32) {
+  TestInfeedRoundTrip(*LiteralUtil::CreateR4(
+      {{{{1, -2}, {-4, 5}, {6, 7}}, {{8, 9}, {10, 11}, {12, 13}}},
+       {{{10, 3}, {7, -2}, {3, 6}}, {{2, 5}, {-11, 5}, {-2, -5}}}}));
+}
+
+// Tests that a large infeed can be handled.
+TEST_F(InfeedTest, LargeInfeed) {
+  Array4D<float> array(80, 100, 8, 128);
+  array.FillIota(1.0f);
+  TestInfeedRoundTrip(*LiteralUtil::CreateR4FromArray4D<float>(array));
+}
+
+TEST_F(InfeedTest, SingleInfeedTuple) {
+  TestInfeedRoundTrip(
+      *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<uint32>({1, 2, 3}).get(),
+                               LiteralUtil::CreateR0<bool>(false).get()}));
+}
+
+TEST_F(InfeedTest, SingleInfeedEmptyTuple) {
+  TestInfeedRoundTrip(*LiteralUtil::MakeTuple({}));
+}
+
+// Tests that a large tuple infeed can be handled.
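+// The tuple below carries 40 * 100 * 8 * 128 = 4,096,000 f32 elements
+// (roughly 15.6 MiB) alongside a scalar in the second tuple element.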
+TEST_F(InfeedTest, SingleInfeedLargeTuple) {
+  Array4D<float> array(40, 100, 8, 128);
+  array.FillIota(1.0f);
+  TestInfeedRoundTrip(*LiteralUtil::MakeTuple(
+      {LiteralUtil::CreateR4FromArray4D<float>(array).get(),
+       LiteralUtil::CreateR0<int32>(5).get()}));
+}
+
+}  // namespace
+}  // namespace xla