aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/util/cuda_kernel_helper.h
blob: ccee269eb3e9306f131f5fb46d125145d47a37b1 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_UTIL_CUDA_KERNEL_HELPER_H_
#define TENSORFLOW_CORE_UTIL_CUDA_KERNEL_HELPER_H_

#if GOOGLE_CUDA

#include <algorithm>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/platform/types.h"

// Expands to the header of a `for` loop that visits the indices
// [0, n) assigned to the calling thread: `i` starts at the thread's flat
// 1-D index (blockIdx.x * blockDim.x + threadIdx.x) and advances by the total
// number of launched threads (blockDim.x * gridDim.x) on every iteration.
//
//   i: name of the `int` loop variable that the macro declares.
//   n: exclusive upper bound. It is parenthesized in the expansion so that
//      compound expressions (e.g. `cond ? a : b`) compare as a whole. Note
//      that it is re-evaluated on every iteration.
#define CUDA_1D_KERNEL_LOOP(i, n)                              \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
       i += blockDim.x * gridDim.x)

namespace tensorflow {

// Short alias for the Eigen GPU device type used by the helpers in this file.
using GPUDevice = Eigen::GpuDevice;

// Parameters for a 1-D kernel launch. Every field defaults to -1, so a
// default-constructed config is recognizably "not filled in".
struct CudaLaunchConfig {
  // Number of logical threads working on the elements. When each logical
  // thread handles exactly one element this equals the element count.
  int virtual_thread_count{-1};
  // Threads in each block.
  int thread_per_block{-1};
  // Blocks in the grid for the Cuda kernel launch.
  int block_count{-1};
};

// Computes a 1-D launch configuration for `work_element_count` elements on
// device `d`. Intended for simple kernels that are expected to be largely
// memory-limited.
//
// The result uses min(1024, device max threads per block) threads per block
// and never more blocks than the device has multiprocessors;
// virtual_thread_count is set to work_element_count.
inline CudaLaunchConfig GetCudaLaunchConfig(int work_element_count,
                                            const GPUDevice& d) {
  const int num_multiprocessors = d.getNumCudaMultiProcessors();

  CudaLaunchConfig config;
  config.virtual_thread_count = work_element_count;
  config.thread_per_block = std::min(1024, d.maxCudaThreadsPerBlock());

  // Threads the device can hold at once, capped by the amount of work.
  const int resident_threads =
      std::min(num_multiprocessors * d.maxCudaThreadsPerMultiProcessor(),
               work_element_count);
  // Ceiling division: blocks needed to host that many threads.
  const int blocks_needed =
      (resident_threads + config.thread_per_block - 1) /
      config.thread_per_block;
  config.block_count = std::min(blocks_needed, num_multiprocessors);
  return config;
}

// Calculate the Cuda launch config we should use for a kernel launch. This
// variant takes the resource limits of func into account to maximize occupancy.
//
//   work_element_count:         number of elements to process; stored as
//                               virtual_thread_count.
//   d:                          device handle (not consulted by this variant).
//   func:                       the kernel that will be launched.
//   dynamic_shared_memory_size: forwarded to
//                               cudaOccupancyMaxPotentialBlockSize.
//   block_size_limit:           forwarded to
//                               cudaOccupancyMaxPotentialBlockSize.
//
// If the occupancy query fails or reports a non-positive block size, the
// returned config has thread_per_block == 0 and block_count == 0 instead of
// dividing by zero; callers can detect the failure by checking for that.
template <typename DeviceFunc>
inline CudaLaunchConfig GetCudaLaunchConfig(int work_element_count,
                                            const GPUDevice& d, DeviceFunc func,
                                            size_t dynamic_shared_memory_size,
                                            int block_size_limit) {
  int block_count = 0;
  int thread_per_block = 0;
  const cudaError_t status = cudaOccupancyMaxPotentialBlockSize(
      &block_count, &thread_per_block, func, dynamic_shared_memory_size,
      block_size_limit);
  if (status != cudaSuccess || thread_per_block <= 0) {
    // Never use thread_per_block as a divisor unless it is known positive.
    block_count = 0;
    thread_per_block = 0;
  } else {
    // Do not launch more blocks than the work requires (ceiling division).
    block_count = std::min(
        block_count,
        (work_element_count + thread_per_block - 1) / thread_per_block);
  }

  CudaLaunchConfig config;
  config.virtual_thread_count = work_element_count;
  config.thread_per_block = thread_per_block;
  config.block_count = block_count;
  return config;
}

// Launch parameters for a kernel that covers a 2-D index space. Each member
// is a dim3; GetCuda2DLaunchConfig() in this file fills in x and y and sets
// z to 1.
struct Cuda2DLaunchConfig {
  // Extent of the logical work domain.
  dim3 virtual_thread_count;
  // Block dimensions (threads per block).
  dim3 thread_per_block;
  // Grid dimensions (number of blocks).
  dim3 block_count;
};

// Computes a 2-D launch configuration covering an xdim-by-ydim index space on
// device `d`.
//
// Blocks hold at most 256 threads, laid out as min(xdim, 256) columns and as
// many rows as still fit. The grid is capped so that the total block count
// does not exceed what the device can keep resident (at least one block);
// kernels are expected to loop when the grid is smaller than the domain.
//
// A zero (or negative) xdim no longer causes a division by zero: the block
// width and grid width are clamped to at least 1. Results for positive
// dimensions are unchanged. virtual_thread_count is still built directly from
// xdim and ydim, so negative dimensions are not meaningful inputs.
inline Cuda2DLaunchConfig GetCuda2DLaunchConfig(int xdim, int ydim,
                                                const GPUDevice& d) {
  Cuda2DLaunchConfig config;

  config.virtual_thread_count = dim3(xdim, ydim, 1);

  const int kThreadsPerBlock = 256;
  // Clamp to >= 1: block_cols is used as a divisor below.
  int block_cols = std::max(std::min(xdim, kThreadsPerBlock), 1);
  // ok to round down here and just do more loops in the kernel
  int block_rows = std::max(kThreadsPerBlock / block_cols, 1);

  const int physical_thread_count =
      d.getNumCudaMultiProcessors() * d.maxCudaThreadsPerMultiProcessor();

  const int max_blocks = std::max(physical_thread_count / kThreadsPerBlock, 1);

  config.thread_per_block = dim3(block_cols, block_rows, 1);

  // Clamp to >= 1: grid_x is used as a divisor below and as a grid dimension.
  int grid_x = std::max(
      std::min((xdim + block_cols - 1) / block_cols, max_blocks), 1);

  config.block_count = dim3(
      grid_x, std::min(max_blocks / grid_x, std::max(ydim / block_rows, 1)), 1);

  return config;
}

namespace gpu {

// Binary search over the `count` elements starting at `first`. Returns the
// offset (relative to the initial `first`) of the first element that compares
// greater than `val`, or `count` if there is none. The search halves the
// candidate range on every step, so the input must be ordered for the result
// to be meaningful.
template <typename IntType>
__device__ IntType upper_bound(IntType* first, IntType count, IntType val) {
  IntType* const base = first;
  for (IntType remaining = count; remaining > 0;) {
    const IntType mid_offset = remaining / 2;
    IntType* const probe = first + mid_offset;
    if (val < *probe) {
      // Answer lies in [first, probe).
      remaining = mid_offset;
    } else {
      // *probe <= val: skip past the probed element.
      first = probe + 1;
      remaining -= mid_offset + 1;
    }
  }
  return first - base;
}

}  // namespace gpu

// Loads *address. When compiled for a device with __CUDA_ARCH__ >= 350 the
// load goes through the __ldg() intrinsic; otherwise (host code or older
// architectures) it is a plain dereference.
template <typename T>
__device__ __host__ inline T ldg(const T* address) {
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 350
  return *address;
#else
  return __ldg(address);
#endif
}

// ldg() specialization for std::complex<float>. On devices with
// __CUDA_ARCH__ >= 350 the value is fetched as a float2 via __ldg() and
// repacked as (real = x, imag = y); elsewhere it is a plain dereference.
template <>
__device__ __host__ inline std::complex<float> ldg(
    const std::complex<float>* address) {
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 350
  return *address;
#else
  const float2* packed_address = reinterpret_cast<const float2*>(address);
  const float2 packed = __ldg(packed_address);
  return std::complex<float>(packed.x, packed.y);
#endif
}

// ldg() specialization for std::complex<double>. On devices with
// __CUDA_ARCH__ >= 350 the value is fetched as a double2 via __ldg() and
// repacked as (real = x, imag = y); elsewhere it is a plain dereference.
template <>
__device__ __host__ inline std::complex<double> ldg(
    const std::complex<double>* address) {
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 350
  return *address;
#else
  const double2* packed_address = reinterpret_cast<const double2*>(address);
  const double2 packed = __ldg(packed_address);
  return std::complex<double>(packed.x, packed.y);
#endif
}

// CUDA provides atomic ops, but not for all types.  We provide wrappers
// for some ops and provide implementation for all reasonable types.
//
// CUDA_ATOMIC_WRAPPER(op, T) expands to the signature
//   __device__ __forceinline__ T CudaAtomic<op>(T* address, T val)
// and must be followed by a function body. The body can refer to the
// parameters `address` and `val`.
#define CUDA_ATOMIC_WRAPPER(op, T) \
  __device__ __forceinline__ T CudaAtomic##op(T* address, T val)

// USE_CUDA_ATOMIC(op, T) defines CudaAtomic<op> for type T as a direct
// forward to the built-in atomic<op>(address, val).
#define USE_CUDA_ATOMIC(op, T) \
  CUDA_ATOMIC_WRAPPER(op, T) { return atomic##op(address, val); }

// For atomicAdd.
// Each line defines a CudaAtomicAdd overload for the named type that forwards
// to the built-in atomicAdd().
USE_CUDA_ATOMIC(Add, int32);
USE_CUDA_ATOMIC(Add, uint32);
USE_CUDA_ATOMIC(Add, uint64);
USE_CUDA_ATOMIC(Add, float);

// For atomicMax.
// int32 and uint32 forward straight to the built-in atomicMax().
USE_CUDA_ATOMIC(Max, int32);
USE_CUDA_ATOMIC(Max, uint32);
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 350
USE_CUDA_ATOMIC(Max, uint64);
#else
// The uint64 overload of atomicMax() requires __CUDA_ARCH__ >= 350. Below
// that, emulate it with a compare-and-swap retry loop. Returns the value
// stored at `address` before the update, like the built-in.
CUDA_ATOMIC_WRAPPER(Max, uint64) {
  uint64* const target = reinterpret_cast<uint64*>(address);
  uint64 expected = *target;
  uint64 observed;
  // Retry until the word we based the maximum on is still the one in memory.
  while ((observed = atomicCAS(target, expected, max(val, expected))) !=
         expected) {
    expected = observed;
  }
  return observed;
}
#endif

// atomicAdd for double, built on a 64-bit atomicCAS() retry loop (the pattern
// given in the CUDA manual). Returns the value stored at `address` before the
// addition.
CUDA_ATOMIC_WRAPPER(Add, double) {
  uint64* const bits_address = reinterpret_cast<uint64*>(address);
  uint64 expected = *bits_address;
  uint64 observed;
  // The loop condition compares raw bit patterns, not doubles, so a NaN in
  // memory (which never compares equal to itself) cannot cause a hang.
  while ((observed = atomicCAS(
              bits_address, expected,
              __double_as_longlong(val + __longlong_as_double(expected)))) !=
         expected) {
    expected = observed;
  }
  return __longlong_as_double(observed);
}

// Helpers for CudaAtomicAdd(half*, half), below.
//
// With __CUDA_ARCH__ >= 530, __hadd2() could probably give a more efficient
// implementation, assuming that adding -0.0 never harms the neighboring
// value. This version instead takes special care that the bits of the
// untouched half are left exactly as they were.
//
// add_to_low_half: interprets the low 16 bits of `val` as an fp16 number,
// adds `x` to it in float precision, rounds back to fp16, and returns `val`
// with only those low 16 bits replaced.
inline __device__ uint32 add_to_low_half(uint32 val, float x) {
  Eigen::half lower;
  lower.x = static_cast<uint16>(val & 0xffffu);
  const float sum = static_cast<float>(lower) + x;
  lower = static_cast<Eigen::half>(sum);
  const uint32 untouched_high_bits = val & 0xffff0000u;
  return untouched_high_bits | lower.x;
}

// Interprets the high 16 bits of `val` as an fp16 number, adds `x` to it in
// float precision, rounds back to fp16, and returns `val` with only those
// high 16 bits replaced; the low 16 bits are preserved bit-for-bit.
inline __device__ uint32 add_to_high_half(uint32 val, float x) {
  Eigen::half high_half;
  high_half.x = static_cast<uint16>(val >> 16);
  high_half = static_cast<Eigen::half>(static_cast<float>(high_half) + x);
  // Widen to uint32 before shifting: the 16-bit field would otherwise be
  // promoted to signed int, and shifting a value with bit 15 set (any negative
  // fp16) into the sign bit is a signed overflow.
  return (val & 0xffffu) | (static_cast<uint32>(high_half.x) << 16);
}

// atomicAdd for half. atomicCAS() does not exist for anything narrower than
// 32 bits, so the operation is performed on the aligned 32-bit word that
// contains the target half; the neighboring 16 bits are carried along
// unchanged.
//
// Unlike the other atomic adds, this one is going to be very slow under high
// concurrency, since most threads will be spinning on failed compare-and-swap
// attempts. (False sharing with the neighboring fp16 makes this even worse.)
// For a large reduction you are much better off doing the intermediate steps
// in fp32 and switching to fp16 as late as you can in the calculation.
//
// Returns the half stored at `address` before the addition.
//
// Note: Assumes little endian.
CUDA_ATOMIC_WRAPPER(Add, Eigen::half) {
  const float val_as_float(val);
  const intptr_t address_int = reinterpret_cast<intptr_t>(address);
  // Bit 1 of the address tells which 16-bit slot of the containing 32-bit
  // word the half occupies: clear -> lower 16 bits, set -> upper 16 bits.
  const bool in_upper_half = (address_int & 0x2) != 0;
  uint32* const address_as_uint32 = reinterpret_cast<uint32*>(
      in_upper_half ? address_int - 2 : address_int);
  assert(((intptr_t)address_as_uint32 & 0x3) == 0);

  uint32 expected = *address_as_uint32;
  uint32 observed;
  for (;;) {
    const uint32 desired = in_upper_half
                               ? add_to_high_half(expected, val_as_float)
                               : add_to_low_half(expected, val_as_float);
    observed = atomicCAS(address_as_uint32, expected, desired);
    // Note: uses integer comparison to avoid hang in case of NaN
    if (observed == expected) break;
    expected = observed;
  }

  Eigen::half ret;
  ret.x = in_upper_half ? (observed >> 16) : (observed & 0xffffu);
  return ret;
}

// Kernel that writes T(0) to bottom_diff[0 .. nthreads). Indices are
// distributed over the launched threads by CUDA_1D_KERNEL_LOOP.
template <typename T>
__global__ void SetZero(const int nthreads, T* bottom_diff) {
  CUDA_1D_KERNEL_LOOP(index, nthreads) {
    bottom_diff[index] = T(0);
  }
}

// For atomicSub.

// Custom implementation for sub by just negating the value.
// WRAPPED_ATOMIC_SUB(T) defines CudaAtomicSub(T* address, T val) as
// CudaAtomicAdd(address, -val); the return value is whatever CudaAtomicAdd
// returns. For the unsigned types, -val wraps modulo 2^N, so the addition
// still produces *address - val modulo 2^N.
#define WRAPPED_ATOMIC_SUB(T) \
  CUDA_ATOMIC_WRAPPER(Sub, T) { return CudaAtomicAdd(address, -val); }

WRAPPED_ATOMIC_SUB(uint64);
WRAPPED_ATOMIC_SUB(int32);
WRAPPED_ATOMIC_SUB(uint32);
WRAPPED_ATOMIC_SUB(float);
WRAPPED_ATOMIC_SUB(double);

// The helper macro is only needed for the definitions above.
#undef WRAPPED_ATOMIC_SUB

// For atomicMul.
// Atomically replaces *address with val * (*address) for int32 using an
// atomicCAS() retry loop. Returns the value held before the update.
CUDA_ATOMIC_WRAPPER(Mul, int32) {
  int32 expected = *address;
  int32 observed;
  while ((observed = atomicCAS(address, expected, val * expected)) !=
         expected) {
    expected = observed;
  }
  return observed;
}

// Atomically replaces *address with val * (*address) for uint32 using an
// atomicCAS() retry loop. Returns the value held before the update.
CUDA_ATOMIC_WRAPPER(Mul, uint32) {
  uint32 expected = *address;
  uint32 observed;
  while ((observed = atomicCAS(address, expected, val * expected)) !=
         expected) {
    expected = observed;
  }
  return observed;
}

// Atomically replaces *address with val * (*address) for uint64 using an
// atomicCAS() retry loop. Returns the value held before the update.
CUDA_ATOMIC_WRAPPER(Mul, uint64) {
  uint64 expected = *address;
  uint64 observed;
  while ((observed = atomicCAS(address, expected, val * expected)) !=
         expected) {
    expected = observed;
  }
  return observed;
}

// Atomically replaces *address with val * (*address) for float. The float is
// accessed through its 32-bit integer bit pattern so that atomicCAS() can be
// used; the loop condition compares those bit patterns. Returns the value
// held before the update.
CUDA_ATOMIC_WRAPPER(Mul, float) {
  int32* const bits_address = reinterpret_cast<int32*>(address);
  int32 expected = *bits_address;
  int32 observed;
  while ((observed = atomicCAS(
              bits_address, expected,
              __float_as_int(val * __int_as_float(expected)))) != expected) {
    expected = observed;
  }
  return __int_as_float(observed);
}

// Atomically replaces *address with val * (*address) for double. The double
// is accessed through its 64-bit integer bit pattern so that atomicCAS() can
// be used; the loop condition compares those bit patterns. Returns the value
// held before the update.
CUDA_ATOMIC_WRAPPER(Mul, double) {
  uint64* const bits_address = reinterpret_cast<uint64*>(address);
  uint64 expected = *bits_address;
  uint64 observed;
  while ((observed = atomicCAS(
              bits_address, expected,
              __double_as_longlong(val * __longlong_as_double(expected)))) !=
         expected) {
    expected = observed;
  }
  return __longlong_as_double(observed);
}

// For atomicDiv.
// Atomically replaces *address with (*address) / val for int32 using an
// atomicCAS() retry loop. Returns the value held before the update.
CUDA_ATOMIC_WRAPPER(Div, int32) {
  int32 expected = *address;
  int32 observed;
  while ((observed = atomicCAS(address, expected, expected / val)) !=
         expected) {
    expected = observed;
  }
  return observed;
}

// Atomically replaces *address with (*address) / val for uint32 using an
// atomicCAS() retry loop. Returns the value held before the update.
CUDA_ATOMIC_WRAPPER(Div, uint32) {
  uint32 expected = *address;
  uint32 observed;
  while ((observed = atomicCAS(address, expected, expected / val)) !=
         expected) {
    expected = observed;
  }
  return observed;
}

// Atomically replaces *address with (*address) / val for uint64 using an
// atomicCAS() retry loop. Returns the value held before the update.
CUDA_ATOMIC_WRAPPER(Div, uint64) {
  uint64 expected = *address;
  uint64 observed;
  while ((observed = atomicCAS(address, expected, expected / val)) !=
         expected) {
    expected = observed;
  }
  return observed;
}

// Atomically replaces *address with (*address) / val for float. The float is
// accessed through its 32-bit integer bit pattern so that atomicCAS() can be
// used; the loop condition compares those bit patterns. Returns the value
// held before the update.
CUDA_ATOMIC_WRAPPER(Div, float) {
  int32* const bits_address = reinterpret_cast<int32*>(address);
  int32 expected = *bits_address;
  int32 observed;
  while ((observed = atomicCAS(
              bits_address, expected,
              __float_as_int(__int_as_float(expected) / val))) != expected) {
    expected = observed;
  }
  return __int_as_float(observed);
}

// Atomically replaces *address with (*address) / val for double. The double
// is accessed through its 64-bit integer bit pattern so that atomicCAS() can
// be used; the loop condition compares those bit patterns. Returns the value
// held before the update.
CUDA_ATOMIC_WRAPPER(Div, double) {
  uint64* const bits_address = reinterpret_cast<uint64*>(address);
  uint64 expected = *bits_address;
  uint64 observed;
  while ((observed = atomicCAS(
              bits_address, expected,
              __double_as_longlong(__longlong_as_double(expected) / val))) !=
         expected) {
    expected = observed;
  }
  return __longlong_as_double(observed);
}

#undef USE_CUDA_ATOMIC
#undef CUDA_ATOMIC_WRAPPER

// Returns y when x > y, otherwise x (so x is returned when the two compare
// equal or are unordered).
template <typename T>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T tf_min(const T& x, const T& y) {
  if (x > y) {
    return y;
  }
  return x;
}

// Returns y when x < y, otherwise x (so x is returned when the two compare
// equal or are unordered).
template <typename T>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T tf_max(const T& x, const T& y) {
  if (x < y) {
    return y;
  }
  return x;
}

// Thin wrapper that forwards to the __shfl() warp-shuffle intrinsic; `width`
// defaults to warpSize.
// NOTE(review): __shfl is the legacy mask-less form; newer toolkits expect
// the *_sync variants with an explicit participant mask.
template <typename T>
__device__ EIGEN_ALWAYS_INLINE T CudaShuffle(T value, int srcLane,
                                             int width = warpSize) {
  const T shuffled = __shfl(value, srcLane, width);
  return shuffled;
}

// Variant of the (undocumented) version from the CUDA SDK, but using unsigned
// instead of float for lo and hi (which is incorrect with ftz, for example).
// A bug has been filed with NVIDIA and will be fixed in the next CUDA release.
// TODO(csigg): remove when the bug is fixed in the next CUDA release.
//
// double overload of CudaShuffle: the 64-bit value is split into two 32-bit
// words, each word is passed through __shfl() with the same srcLane/width,
// and the two results are packed back into a double.
__device__ EIGEN_ALWAYS_INLINE double CudaShuffle(double value, int srcLane,
                                                  int width = warpSize) {
  unsigned lo, hi;
  // Unpack the 64-bit register holding `value` into the 32-bit pair {lo, hi}.
  asm volatile("mov.b64 {%0,%1}, %2;" : "=r"(lo), "=r"(hi) : "d"(value));
  hi = __shfl(hi, srcLane, width);
  lo = __shfl(lo, srcLane, width);
  // Pack the shuffled {lo, hi} pair back into a 64-bit double register.
  asm volatile("mov.b64 %0, {%1,%2};" : "=d"(value) : "r"(lo), "r"(hi));
  return value;
}

// Thin wrapper that forwards to the __shfl_up() warp-shuffle intrinsic;
// `width` defaults to warpSize.
// NOTE(review): __shfl_up is the legacy mask-less form; newer toolkits expect
// the *_sync variants with an explicit participant mask.
template <typename T>
__device__ EIGEN_ALWAYS_INLINE T CudaShuffleUp(T value, int delta,
                                               int width = warpSize) {
  const T shuffled = __shfl_up(value, delta, width);
  return shuffled;
}

// Variant of the (undocumented) version from the CUDA SDK, but using unsigned
// instead of float for lo and hi (which is incorrect with ftz, for example).
// A bug has been filed with NVIDIA and will be fixed in the next CUDA release.
// TODO(csigg): remove when the bug is fixed in the next CUDA release.
//
// double overload of CudaShuffleUp: the 64-bit value is split into two 32-bit
// words, each word is passed through __shfl_up() with the same delta/width,
// and the two results are packed back into a double.
__device__ EIGEN_ALWAYS_INLINE double CudaShuffleUp(double value, int delta,
                                                    int width = warpSize) {
  unsigned lo, hi;
  // Unpack the 64-bit register holding `value` into the 32-bit pair {lo, hi}.
  asm volatile("mov.b64 {%0,%1}, %2;" : "=r"(lo), "=r"(hi) : "d"(value));
  hi = __shfl_up(hi, delta, width);
  lo = __shfl_up(lo, delta, width);
  // Pack the shuffled {lo, hi} pair back into a 64-bit double register.
  asm volatile("mov.b64 %0, {%1,%2};" : "=d"(value) : "r"(lo), "r"(hi));
  return value;
}

// Thin wrapper that forwards to the __shfl_down() warp-shuffle intrinsic;
// `width` defaults to warpSize.
// NOTE(review): __shfl_down is the legacy mask-less form; newer toolkits
// expect the *_sync variants with an explicit participant mask.
template <typename T>
__device__ EIGEN_ALWAYS_INLINE T CudaShuffleDown(T value, int delta,
                                                 int width = warpSize) {
  const T shuffled = __shfl_down(value, delta, width);
  return shuffled;
}

// Variant of the (undocumented) version from the CUDA SDK, but using unsigned
// instead of float for lo and hi (which is incorrect with ftz, for example).
// A bug has been filed with NVIDIA and will be fixed in the next CUDA release.
// TODO(csigg): remove when the bug is fixed in the next CUDA release.
//
// double overload of CudaShuffleDown: the 64-bit value is split into two
// 32-bit words, each word is passed through __shfl_down() with the same
// delta/width, and the two results are packed back into a double.
__device__ EIGEN_ALWAYS_INLINE double CudaShuffleDown(double value, int delta,
                                                      int width = warpSize) {
  unsigned lo, hi;
  // Unpack the 64-bit register holding `value` into the 32-bit pair {lo, hi}.
  asm volatile("mov.b64 {%0,%1}, %2;" : "=r"(lo), "=r"(hi) : "d"(value));
  hi = __shfl_down(hi, delta, width);
  lo = __shfl_down(lo, delta, width);
  // Pack the shuffled {lo, hi} pair back into a 64-bit double register.
  asm volatile("mov.b64 %0, {%1,%2};" : "=d"(value) : "r"(lo), "r"(hi));
  return value;
}

// Thin wrapper that forwards to the __shfl_xor() warp-shuffle intrinsic;
// `width` defaults to warpSize.
// NOTE(review): __shfl_xor is the legacy mask-less form; newer toolkits
// expect the *_sync variants with an explicit participant mask.
template <typename T>
__device__ EIGEN_ALWAYS_INLINE T CudaShuffleXor(T value, int laneMask,
                                                int width = warpSize) {
  const T shuffled = __shfl_xor(value, laneMask, width);
  return shuffled;
}

// Variant of the (undocumented) version from the CUDA SDK, but using unsigned
// instead of float for lo and hi (which is incorrect with ftz, for example).
// A bug has been filed with NVIDIA and will be fixed in the next CUDA release.
// TODO(csigg): remove when the bug is fixed in the next CUDA release.
//
// double overload of CudaShuffleXor: the 64-bit value is split into two
// 32-bit words, each word is passed through __shfl_xor() with the same
// laneMask/width, and the two results are packed back into a double.
__device__ EIGEN_ALWAYS_INLINE double CudaShuffleXor(double value, int laneMask,
                                                     int width = warpSize) {
  unsigned lo, hi;
  // Unpack the 64-bit register holding `value` into the 32-bit pair {lo, hi}.
  asm volatile("mov.b64 {%0,%1}, %2;" : "=r"(lo), "=r"(hi) : "d"(value));
  hi = __shfl_xor(hi, laneMask, width);
  lo = __shfl_xor(lo, laneMask, width);
  // Pack the shuffled {lo, hi} pair back into a 64-bit double register.
  asm volatile("mov.b64 %0, {%1,%2};" : "=d"(value) : "r"(lo), "r"(hi));
  return value;
}

}  // namespace tensorflow

#endif  // GOOGLE_CUDA

#endif  // TENSORFLOW_CORE_UTIL_CUDA_KERNEL_HELPER_H_