aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc
blob: cca35316f0c472d2a17c466f8cd1af7f22575a8b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <utility>

#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/core/platform/test.h"

namespace xla {
namespace gpu {
namespace {

// Fixture for the GPU kernel tiling tests. It builds an HloModuleConfig whose
// debug options disable the "layout_assignment" HLO pass, so that the layouts
// written in each test's HLO text are the ones the compiler sees.
class GpuKernelTilingTest : public GpuCodegenTest {
 protected:
  GpuKernelTilingTest() {
    auto debug_options = HloTestBase::GetDebugOptionsForTest();
    // Disable layout_assignment to use the preassigned layouts.
    debug_options.add_xla_disable_hlo_passes("layout_assignment");
    // Install the options only once they are fully populated. debug_options is
    // a local value, so any change made to it after this call would not be
    // reflected in config_.
    config_.set_debug_options(debug_options);
  }

  // Module configuration the tests pass to ParseHloString.
  HloModuleConfig config_;
};

// A standalone (unnested) layout-changing copy of an f16[32,3,64] array is
// expected to compile to a kernel containing a barrier, and to produce the
// right values when run.
TEST_F(GpuKernelTilingTest, UnnestedTransposeWithProperDimensionsTiled) {
  // Copy that changes the layout from {2,1,0} to {1,0,2}.
  const char* const kHloText = R"(
    HloModule unnested_transpose_1

    ENTRY unnested_transpose_1 {
      para0 = f16[32,3,64]{2,1,0} parameter(0)
      ROOT copy1 = f16[32,3,64]{1,0,2} copy(para0)
    })";

  // FileCheck pattern: the @copy kernel must contain a call to
  // llvm.nvvm.barrier0.
  const char* const kExpectedIr = R"(
; CHECK-LABEL: define void @copy
; CHECK: tail call void @llvm.nvvm.barrier0()
; CHECK: }
)";

  auto module = ParseHloString(kHloText, config_).ValueOrDie();
  CompileAndVerifyIr(std::move(module), kExpectedIr,
                     /*match_optimized_ir=*/true);

  // The compiled kernel must also compute the correct result (zero error
  // tolerance).
  const ErrorSpec exact_match{0.0};
  EXPECT_TRUE(RunAndCompareNoHloPasses(kHloText, exact_match));
}

// Same layout-changing copy as the "proper dimensions" case but on a much
// smaller f16[2,3,64] array: here the generated kernel is expected NOT to
// contain a barrier.
TEST_F(GpuKernelTilingTest, UnnestedTransposeWithSmallDimensionsNotTiled) {
  // Copy that changes the layout from {2,1,0} to {1,0,2}.
  const char* const kHloText = R"(
    HloModule unnested_transpose_2

    ENTRY unnested_transpose_2 {
      para0 = f16[2,3,64]{2,1,0} parameter(0)
      ROOT copy1 = f16[2,3,64]{1,0,2} copy(para0)
    })";

  // FileCheck pattern: no call to llvm.nvvm.barrier0 may appear between the
  // @copy label and the next closing brace.
  const char* const kExpectedIr = R"(
; CHECK-LABEL: define void @copy
; CHECK-NOT: tail call void @llvm.nvvm.barrier0()
; CHECK: }
)";

  auto module = ParseHloString(kHloText, config_).ValueOrDie();
  CompileAndVerifyIr(std::move(module), kExpectedIr,
                     /*match_optimized_ir=*/true);
}

// A loop fusion consisting of a layout-changing copy followed by an f32->f16
// convert is expected to compile to a kernel containing a barrier, and to
// produce the right values when run.
TEST_F(GpuKernelTilingTest, SimpleFusionWithTransposeTiled) {
  // The fused copy changes the layout from {4,3,2,1,0} to {2,1,4,3,0}.
  const char* const kHloText = R"(
    HloModule multiple_output_fusion_1
    fused_computation.1 {
      param0 = f32[4,5,6,7,8]{4,3,2,1,0} parameter(0)
      copy = f32[4,5,6,7,8]{2,1,4,3,0} copy(param0)
      ROOT convert = f16[4,5,6,7,8]{2,1,4,3,0} convert(copy)
    }

    ENTRY copy_in_fusion_run_without_hlo_passes {
      para0 = f32[4,5,6,7,8]{4,3,2,1,0} parameter(0)
      ROOT fusion.1 = f16[4,5,6,7,8]{2,1,4,3,0} fusion(para0), kind=kLoop,
        calls=fused_computation.1
    })";

  // FileCheck pattern: the @fusion kernel must contain a call to
  // llvm.nvvm.barrier0.
  const char* const kExpectedIr = R"(
; CHECK-LABEL: define void @fusion
; CHECK: tail call void @llvm.nvvm.barrier0()
; CHECK: }
)";

  auto module = ParseHloString(kHloText, config_).ValueOrDie();
  CompileAndVerifyIr(std::move(module), kExpectedIr,
                     /*match_optimized_ir=*/true);

  // The compiled kernel must also compute the correct result (zero error
  // tolerance).
  const ErrorSpec exact_match{0.0};
  EXPECT_TRUE(RunAndCompareNoHloPasses(kHloText, exact_match));
}

// A multi-output loop fusion whose two copies both go from layout {3,2,1,0}
// to {2,1,3,0} is expected to compile to a kernel containing a barrier, and to
// produce the right values when run.
TEST_F(GpuKernelTilingTest, MultipleOutputFusionWithOnePossibleTransposeTiled) {
  // Both fusion parameters share the same input layout, so both copies perform
  // the same layout change.
  const char* const kHloText = R"(
    HloModule multiple_output_fusion_1
    fused_computation.1 {
      param0 = f16[8,31,31,65]{3,2,1,0} parameter(0)
      param1 = f16[8,31,31,65]{3,2,1,0} parameter(1)
      copy0 = f16[8,31,31,65]{2,1,3,0} copy(param0)
      copy1 = f16[8,31,31,65]{2,1,3,0} copy(param1)
      ROOT tuple1 = (f16[8,31,31,65]{2,1,3,0}, f16[8,31,31,65]{2,1,3,0})
        tuple(copy0, copy1)
    }

    ENTRY multiple_output_fusion_1 {
      para0 = f16[8,31,31,65]{3,2,1,0} parameter(0)
      para1 = f16[8,31,31,65]{3,2,1,0} parameter(1)
      ROOT fusion.1 = (f16[8,31,31,65]{2,1,3,0}, f16[8,31,31,65]{2,1,3,0})
        fusion(para0,para1), kind=kLoop, calls=fused_computation.1
    })";

  // FileCheck pattern: the @fusion kernel must contain a call to
  // llvm.nvvm.barrier0.
  const char* const kExpectedIr = R"(
; CHECK-LABEL: define void @fusion
; CHECK: tail call void @llvm.nvvm.barrier0()
; CHECK: }
)";

  auto module = ParseHloString(kHloText, config_).ValueOrDie();
  CompileAndVerifyIr(std::move(module), kExpectedIr,
                     /*match_optimized_ir=*/true);

  // The compiled kernel must also compute the correct result (zero error
  // tolerance).
  const ErrorSpec exact_match{0.0};
  EXPECT_TRUE(RunAndCompareNoHloPasses(kHloText, exact_match));
}

// A multi-output loop fusion whose two copies start from different input
// layouts ({3,2,1,0} and {1,3,2,0}) is expected to compile to a kernel that
// does NOT contain a barrier.
TEST_F(GpuKernelTilingTest,
       MultipleOutputFusionWithTwoPossibleTransposesNotTiled) {
  // The two fusion parameters have different layouts, so the two copies
  // perform different layout changes.
  const char* const kHloText = R"(
    HloModule multiple_output_fusion_2
    fused_computation.1 {
      param0 = f16[8,31,31,65]{3,2,1,0} parameter(0)
      param1 = f16[8,31,31,65]{1,3,2,0} parameter(1)
      copy2 = f16[8,31,31,65]{2,1,3,0} copy(param0)
      copy3 = f16[8,31,31,65]{2,1,3,0} copy(param1)
      ROOT tuple1 = (f16[8,31,31,65]{2,1,3,0}, f16[8,31,31,65]{2,1,3,0})
        tuple(copy2, copy3)
    }

    ENTRY multiple_output_fusion_2 {
      para0 = f16[8,31,31,65]{3,2,1,0} parameter(0)
      para1 = f16[8,31,31,65]{1,3,2,0} parameter(1)
      ROOT fusion1 = (f16[8,31,31,65]{2,1,3,0}, f16[8,31,31,65]{2,1,3,0})
        fusion(para0,para1), kind=kLoop, calls=fused_computation.1
    })";

  // FileCheck pattern: no call to llvm.nvvm.barrier0 may appear between the
  // @fusion label and the next closing brace.
  const char* const kExpectedIr = R"(
; CHECK-LABEL: define void @fusion
; CHECK-NOT: tail call void @llvm.nvvm.barrier0()
; CHECK: }
)";

  auto module = ParseHloString(kHloText, config_).ValueOrDie();
  CompileAndVerifyIr(std::move(module), kExpectedIr,
                     /*match_optimized_ir=*/true);
}

}  // namespace
}  // namespace gpu
}  // namespace xla