aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/util/cuda_kernel_helper_test.cu.cc
blob: 732ed33ede17bc90d3301d3f1eee6302a96028d7 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA
#define EIGEN_USE_GPU

#include <numeric>
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"

// Synchronizes the device, then reports the last recorded CUDA error (if any)
// as a non-fatal gtest failure annotated with the CUDA error string.
// Expands to a braced block, so it can be used without a trailing semicolon.
#define CUDA_EXPECT_SUCCESS                             \
  {                                                     \
    cudaDeviceSynchronize();                            \
    const cudaError_t cuda_status = cudaGetLastError(); \
    EXPECT_EQ(cudaSuccess, cuda_status)                 \
        << cudaGetErrorString(cuda_status);             \
  }

// Synchronizes the device, then reports the last recorded CUDA error (if any)
// as a fatal gtest failure annotated with the CUDA error string; on failure
// the enclosing function returns immediately (ASSERT_EQ semantics).
// Expands to a braced block, so it can be used without a trailing semicolon.
#define CUDA_ASSERT_SUCCESS                             \
  {                                                     \
    cudaDeviceSynchronize();                            \
    const cudaError_t cuda_status = cudaGetLastError(); \
    ASSERT_EQ(cudaSuccess, cuda_status)                 \
        << cudaGetErrorString(cuda_status);             \
  }

namespace tensorflow {

namespace {

// Writes 0 to outbuf[i] for every index i produced by the 1D kernel loop over
// config.virtual_thread_count. outbuf must hold at least that many ints.
__global__ void SetOutbufZero(CudaLaunchConfig config, int* outbuf) {
  CUDA_1D_KERNEL_LOOP(i, config.virtual_thread_count) {
    outbuf[i] = 0;
  }
}

// Tallies the work items visited by the 1D kernel loop: each non-negative
// index i atomically adds 1 to outbuf[i % bufsize]. After the kernel finishes,
// the sum over outbuf equals the number of indices that were visited.
__global__ void Count1D(CudaLaunchConfig config, int bufsize, int* outbuf) {
  CUDA_1D_KERNEL_LOOP(i, config.virtual_thread_count) {
    // A negative index means the counter overflowed (possible in the extreme
    // test cases); stop this thread's loop.
    if (i < 0) break;
    int* slot = &outbuf[i % bufsize];
    atomicAdd(slot, 1);
  }
}
// 2D counterpart of the counting kernel: for every (ix, iy) pair produced by
// the nested X/Y axis loops, atomically adds 1 to the outbuf slot selected by
// the row-major flattened index modulo bufsize.
__global__ void Count2D(Cuda2DLaunchConfig config, int bufsize, int* outbuf) {
  const auto y_extent = config.virtual_thread_count.y;
  CUDA_AXIS_KERNEL_LOOP(ix, config.virtual_thread_count.x, X) {
    // A negative ix means the counter overflowed (extreme test sizes).
    if (ix < 0) break;
    CUDA_AXIS_KERNEL_LOOP(iy, y_extent, Y) {
      // A negative iy means the counter overflowed (extreme test sizes).
      if (iy < 0) break;
      int flat = ix * y_extent + iy;
      atomicAdd(&outbuf[flat % bufsize], 1);
    }
  }
}
// 3D counterpart of the counting kernel: for every (ix, iy, iz) triple
// produced by the nested X/Y/Z axis loops, atomically adds 1 to the outbuf
// slot selected by the row-major flattened index modulo bufsize.
__global__ void Count3D(Cuda3DLaunchConfig config, int bufsize, int* outbuf) {
  const auto y_extent = config.virtual_thread_count.y;
  const auto z_extent = config.virtual_thread_count.z;
  CUDA_AXIS_KERNEL_LOOP(ix, config.virtual_thread_count.x, X) {
    // A negative ix means the counter overflowed (extreme test sizes).
    if (ix < 0) break;
    CUDA_AXIS_KERNEL_LOOP(iy, y_extent, Y) {
      // A negative iy means the counter overflowed (extreme test sizes).
      if (iy < 0) break;
      CUDA_AXIS_KERNEL_LOOP(iz, z_extent, Z) {
        // A negative iz means the counter overflowed (extreme test sizes).
        if (iz < 0) break;
        int flat = ix * y_extent * z_extent + iy * z_extent + iz;
        atomicAdd(&outbuf[flat % bufsize], 1);
      }
    }
  }
}

// Checks, on the device, that each detail::CudaShuffle*GetSrcLane helper
// predicts the same source lane that the matching CudaShuffle*Sync call
// actually reads from. Every lane shuffles its own lane id, so the value a
// lane receives is used as the expected source lane.
//
// The checks are repeated for each width warpSize, warpSize/2, ..., 2. Each
// mismatch prints a diagnostic and atomically increments *failure_count, which
// the caller must zero before the launch. The shuffles use the kCudaWarpAll
// mask.
__global__ void CudaShuffleGetSrcLaneTest(unsigned* failure_count) {
  unsigned lane_id = CudaLaneId();
  for (int width = warpSize; width > 1; width /= 2) {
    // Records one failure when the predicted lane differs from the lane the
    // shuffle really read from.
    auto check_result = [&](const char* op_name, int param, unsigned actual,
                            unsigned expected) {
      if (actual != expected) {
        // lane_id, actual and expected are unsigned, hence %u.
        printf("Cuda%sGetSrcLane(%d, %d) for lane %u returned %u, not %u\n",
               op_name, param, width, lane_id, actual, expected);
        CudaAtomicAdd(failure_count, 1);
      }
    };
    // Indexed shuffle: source lanes from -warpSize to warpSize inclusive.
    for (int src_lane = -warpSize; src_lane <= warpSize; ++src_lane) {
      unsigned actual_lane = detail::CudaShuffleGetSrcLane(src_lane, width);
      unsigned expect_lane =
          CudaShuffleSync(kCudaWarpAll, lane_id, src_lane, width);
      check_result("Shuffle", src_lane, actual_lane, expect_lane);
    }
    // Shuffle up: deltas from 0 to warpSize inclusive.
    for (unsigned delta = 0; delta <= warpSize; ++delta) {
      unsigned actual_lane = detail::CudaShuffleUpGetSrcLane(delta, width);
      unsigned expect_lane =
          CudaShuffleUpSync(kCudaWarpAll, lane_id, delta, width);
      check_result("ShuffleUp", delta, actual_lane, expect_lane);
    }
    // Shuffle down: deltas from 0 to warpSize inclusive.
    for (unsigned delta = 0; delta <= warpSize; ++delta) {
      unsigned actual_lane = detail::CudaShuffleDownGetSrcLane(delta, width);
      unsigned expect_lane =
          CudaShuffleDownSync(kCudaWarpAll, lane_id, delta, width);
      check_result("ShuffleDown", delta, actual_lane, expect_lane);
    }
    // Shuffle xor: masks warpSize, warpSize/2, ..., 1.
    for (int lane_mask = warpSize; lane_mask > 0; lane_mask /= 2) {
      unsigned actual_lane = detail::CudaShuffleXorGetSrcLane(lane_mask, width);
      unsigned expect_lane =
          CudaShuffleXorSync(kCudaWarpAll, lane_id, lane_mask, width);
      check_result("ShuffleXor", lane_mask, actual_lane, expect_lane);
    }
  }
}

}  // namespace

// Fixture for the launch-config tests. It owns a managed-memory buffer of
// `bufsize` ints (allocated per test in SetUp, released in TearDown) and the
// Eigen stream/device pair handed to the GetCuda*LaunchConfig helpers.
class CudaLaunchConfigTest : public ::testing::Test {
 protected:
  // Number of int slots in outbuf.
  const int bufsize = 1024;
  // Managed allocation; kernels write to it and the host reads it back.
  int* outbuf = nullptr;
  // `stream` must be declared before `d`, which is constructed from it.
  Eigen::CudaStreamDevice stream;
  Eigen::GpuDevice d = Eigen::GpuDevice(&stream);

  // Allocates outbuf; the test fails fatally if the allocation fails.
  void SetUp() override {
    const cudaError_t alloc_status =
        cudaMallocManaged(&outbuf, bufsize * sizeof(int));
    ASSERT_EQ(cudaSuccess, alloc_status) << cudaGetErrorString(alloc_status);
  }

  // Waits for outstanding device work before releasing outbuf.
  void TearDown() override {
    cudaDeviceSynchronize();
    cudaFree(outbuf);
    outbuf = nullptr;
  }
};

// Launches Count1D with configs returned by GetCudaLaunchConfig and expects
// the tallies in outbuf to sum to the requested number of work elements.
TEST_F(CudaLaunchConfigTest, GetCudaLaunchConfig) {
  CudaLaunchConfig launch_cfg;

// For one work size, performs two rounds of: zero outbuf with SetOutbufZero
// (fatal on CUDA error), run Count1D (non-fatal on CUDA error), then compare
// the sum of outbuf with the work size. The first round uses the
// (count, device) overload of GetCudaLaunchConfig; the second uses the
// overload that also takes the kernel and two trailing 0 arguments.
#define TEST_LAUNCH_PARAMETER(work_element_count)                             \
  launch_cfg = GetCudaLaunchConfig(bufsize, d);                               \
  SetOutbufZero<<<launch_cfg.block_count, launch_cfg.thread_per_block, 0,     \
                  d.stream()>>>(launch_cfg, outbuf);                          \
  CUDA_ASSERT_SUCCESS                                                         \
  launch_cfg = GetCudaLaunchConfig(work_element_count, d);                    \
  Count1D<<<launch_cfg.block_count, launch_cfg.thread_per_block, 0,           \
            d.stream()>>>(launch_cfg, bufsize, outbuf);                       \
  CUDA_EXPECT_SUCCESS                                                         \
  EXPECT_EQ(work_element_count, std::accumulate(outbuf, outbuf + bufsize, 0)); \
                                                                              \
  launch_cfg = GetCudaLaunchConfig(bufsize, d, SetOutbufZero, 0, 0);          \
  SetOutbufZero<<<launch_cfg.block_count, launch_cfg.thread_per_block, 0,     \
                  d.stream()>>>(launch_cfg, outbuf);                          \
  CUDA_ASSERT_SUCCESS                                                         \
  launch_cfg = GetCudaLaunchConfig(work_element_count, d, Count1D, 0, 0);     \
  Count1D<<<launch_cfg.block_count, launch_cfg.thread_per_block, 0,           \
            d.stream()>>>(launch_cfg, bufsize, outbuf);                       \
  CUDA_EXPECT_SUCCESS                                                         \
  EXPECT_EQ(work_element_count, std::accumulate(outbuf, outbuf + bufsize, 0))

  TEST_LAUNCH_PARAMETER(128);
  TEST_LAUNCH_PARAMETER(129);
  TEST_LAUNCH_PARAMETER(511);
  TEST_LAUNCH_PARAMETER(512);
  TEST_LAUNCH_PARAMETER(2048);
  TEST_LAUNCH_PARAMETER(2049);
  TEST_LAUNCH_PARAMETER(8191);
  TEST_LAUNCH_PARAMETER(8192);
  TEST_LAUNCH_PARAMETER(123456);
  TEST_LAUNCH_PARAMETER(1 << 30);
#undef TEST_LAUNCH_PARAMETER
}

// Field-wise equality for launch configs: returns true only when
// virtual_thread_count, thread_per_block and block_count all agree in their
// x, y and z components.
bool operator==(const Cuda2DLaunchConfig& a, const Cuda2DLaunchConfig& b) {
  return a.virtual_thread_count.x == b.virtual_thread_count.x &&
         a.virtual_thread_count.y == b.virtual_thread_count.y &&
         a.virtual_thread_count.z == b.virtual_thread_count.z &&
         a.thread_per_block.x == b.thread_per_block.x &&
         a.thread_per_block.y == b.thread_per_block.y &&
         a.thread_per_block.z == b.thread_per_block.z &&
         a.block_count.x == b.block_count.x &&
         a.block_count.y == b.block_count.y &&
         a.block_count.z == b.block_count.z;
}

// Launches Count2D with configs returned by GetCuda2DLaunchConfig and expects
// the tallies in outbuf to sum to dimx * dimy.
TEST_F(CudaLaunchConfigTest, GetCuda2DLaunchConfig) {
  Cuda2DLaunchConfig cfg;
  CudaLaunchConfig cfg1d;

// For one (dimx, dimy) pair, performs two rounds of: zero outbuf with
// SetOutbufZero (fatal on CUDA error), run Count2D (non-fatal on CUDA error),
// then compare the sum of outbuf with dimx * dimy. The first round uses the
// overloads without a kernel argument; the second uses the overloads that
// also take the kernel and two trailing 0 arguments.
// The arguments are parenthesized in the product so that expression
// arguments such as `1 << 30` keep their intended grouping.
#define TEST_LAUNCH_PARAMETER(dimx, dimy)                                      \
  cfg1d = GetCudaLaunchConfig(bufsize, d);                                     \
  SetOutbufZero<<<cfg1d.block_count, cfg1d.thread_per_block, 0, d.stream()>>>( \
      cfg1d, outbuf);                                                          \
  CUDA_ASSERT_SUCCESS                                                          \
  cfg = GetCuda2DLaunchConfig(dimx, dimy, d);                                  \
  Count2D<<<cfg.block_count, cfg.thread_per_block, 0, d.stream()>>>(           \
      cfg, bufsize, outbuf);                                                   \
  CUDA_EXPECT_SUCCESS                                                          \
  EXPECT_EQ((dimx) * (dimy), std::accumulate(outbuf, outbuf + bufsize, 0));    \
                                                                               \
  cfg1d = GetCudaLaunchConfig(bufsize, d, SetOutbufZero, 0, 0);                \
  SetOutbufZero<<<cfg1d.block_count, cfg1d.thread_per_block, 0, d.stream()>>>( \
      cfg1d, outbuf);                                                          \
  CUDA_ASSERT_SUCCESS                                                          \
  cfg = GetCuda2DLaunchConfig(dimx, dimy, d, Count2D, 0, 0);                   \
  Count2D<<<cfg.block_count, cfg.thread_per_block, 0, d.stream()>>>(           \
      cfg, bufsize, outbuf);                                                   \
  CUDA_EXPECT_SUCCESS                                                          \
  EXPECT_EQ((dimx) * (dimy), std::accumulate(outbuf, outbuf + bufsize, 0))

  TEST_LAUNCH_PARAMETER(128, 128);
  TEST_LAUNCH_PARAMETER(129, 64);
  TEST_LAUNCH_PARAMETER(511, 2048);
  TEST_LAUNCH_PARAMETER(512, 512);
  TEST_LAUNCH_PARAMETER(2048, 1024);
  TEST_LAUNCH_PARAMETER(2049, 32);
  TEST_LAUNCH_PARAMETER(8191, 1);
  TEST_LAUNCH_PARAMETER(8192, 10);
  TEST_LAUNCH_PARAMETER(123456, 12);
  TEST_LAUNCH_PARAMETER(1, 1 << 30);
  TEST_LAUNCH_PARAMETER(1 << 30, 1);
#undef TEST_LAUNCH_PARAMETER
}

// Launches Count3D with configs returned by GetCuda3DLaunchConfig and expects
// the tallies in outbuf to sum to dimx * dimy * dimz.
TEST_F(CudaLaunchConfigTest, GetCuda3DLaunchConfig) {
  Cuda3DLaunchConfig cfg;
  CudaLaunchConfig cfg1d;

// For one (dimx, dimy, dimz) triple: zero outbuf with SetOutbufZero (fatal on
// CUDA error), run Count3D (non-fatal on CUDA error), then compare the sum of
// outbuf with dimx * dimy * dimz. Both configs come from the overloads that
// take the kernel and two trailing 0 arguments.
// The arguments are parenthesized in the product so that expression
// arguments such as `1 << 30` keep their intended grouping.
#define TEST_LAUNCH_PARAMETER(dimx, dimy, dimz)                                \
  cfg1d = GetCudaLaunchConfig(bufsize, d, SetOutbufZero, 0, 0);                \
  SetOutbufZero<<<cfg1d.block_count, cfg1d.thread_per_block, 0, d.stream()>>>( \
      cfg1d, outbuf);                                                          \
  CUDA_ASSERT_SUCCESS                                                          \
  cfg = GetCuda3DLaunchConfig(dimx, dimy, dimz, d, Count3D, 0, 0);             \
  Count3D<<<cfg.block_count, cfg.thread_per_block, 0, d.stream()>>>(           \
      cfg, bufsize, outbuf);                                                   \
  CUDA_EXPECT_SUCCESS                                                          \
  EXPECT_EQ((dimx) * (dimy) * (dimz),                                          \
            std::accumulate(outbuf, outbuf + bufsize, 0))

  TEST_LAUNCH_PARAMETER(128, 128, 128);
  TEST_LAUNCH_PARAMETER(129, 64, 1024);
  TEST_LAUNCH_PARAMETER(511, 2048, 128);
  TEST_LAUNCH_PARAMETER(512, 512, 64);
  TEST_LAUNCH_PARAMETER(2048, 1024, 128);
  TEST_LAUNCH_PARAMETER(2049, 32, 1024);
  TEST_LAUNCH_PARAMETER(8191, 1, 1024);
  TEST_LAUNCH_PARAMETER(8192, 10, 32);
  TEST_LAUNCH_PARAMETER(123456, 12, 21);
  TEST_LAUNCH_PARAMETER(1, 1, 1 << 30);
  TEST_LAUNCH_PARAMETER(1, 1 << 30, 1);
  TEST_LAUNCH_PARAMETER(1 << 30, 1, 1);
#undef TEST_LAUNCH_PARAMETER
}

// Runs CudaShuffleGetSrcLaneTest on a single block of 32 threads and expects
// the kernel to record zero mismatches in the managed failure counter.
TEST(CudaDeviceFunctionsTest, ShuffleGetSrcLane) {
  unsigned* failure_count = nullptr;
  // Nothing has been allocated yet, so a fatal assertion is safe here.
  ASSERT_EQ(cudaMallocManaged(&failure_count, sizeof(unsigned)), cudaSuccess);
  *failure_count = 0;
  CudaShuffleGetSrcLaneTest<<<1, 32>>>(failure_count);

  // Check the launch itself: if the kernel never ran, the counter would stay
  // at zero and the test would otherwise pass without checking anything.
  const cudaError_t launch_status = cudaGetLastError();
  EXPECT_EQ(launch_status, cudaSuccess) << cudaGetErrorString(launch_status);

  // Non-fatal checks from here on so that the buffer is always released.
  const cudaError_t sync_status = cudaDeviceSynchronize();
  EXPECT_EQ(sync_status, cudaSuccess) << cudaGetErrorString(sync_status);
  if (sync_status == cudaSuccess) {
    // Only read the counter from the host once the kernel is known to have
    // finished without error.
    EXPECT_EQ(*failure_count, 0u);
  }
  cudaFree(failure_count);
}

}  // namespace tensorflow

#endif  // GOOGLE_CUDA