// tensorflow/core/kernels/check_numerics_op_gpu.cu.cc
#if GOOGLE_CUDA
#define EIGEN_USE_GPU

#include <assert.h>
#include <math.h>
#include <stdio.h>

#include <algorithm>

#include "tensorflow/core/platform/port.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

namespace tensorflow {

namespace {

typedef Eigen::GpuDevice GPUDevice;

// A CUDA kernel that checks whether each element is Inf or NaN. If any are
// found, the corresponding flags in abnormal_detected are set.
template <typename T>
__global__ void CheckNumericsKernel(const T *data, int size,
                                    int abnormal_detected[2]) {
  const int32 thread_id = blockIdx.x * blockDim.x + threadIdx.x;
  const int32 total_thread_count = gridDim.x * blockDim.x;

  int32 offset = thread_id;

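  // Grid-stride loop: each thread starts at its global index and advances by
  // the total number of threads in the grid, so every one of the `size`
  // elements is visited regardless of the launch configuration.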
  while (offset < size) {
    if (isnan(data[offset])) {
      abnormal_detected[0] = 1;
    }
    if (isinf(data[offset])) {
      abnormal_detected[1] = 1;
    }
    offset += total_thread_count;
  }
}

}  // namespace

// A simple launch pad that launches the CUDA kernel to check the given array
// for numerical abnormalities.
template <typename T>
struct CheckNumericsLaunch {
  void Run(const GPUDevice &d, const T *data, int size,
           int abnormal_detected[2]) {
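    // Size the grid so that every multiprocessor can be filled to its maximum
    // resident thread count; the grid-stride loop inside the kernel covers any
    // elements beyond a single pass of the grid.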
    const int32 block_size = d.maxCudaThreadsPerBlock();
    const int32 num_blocks =
        (d.getNumCudaMultiProcessors() * d.maxCudaThreadsPerMultiProcessor()) /
        block_size;

    CheckNumericsKernel<T><<<num_blocks, block_size, 0, d.stream()>>>(
        data, size, abnormal_detected);
  }
};

template struct CheckNumericsLaunch<float>;
template struct CheckNumericsLaunch<double>;
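
// A minimal host-side usage sketch (illustrative only; `device`,
// `device_data`, and `d_abnormal` are hypothetical names, not part of this
// file). The caller owns `abnormal_detected`, which must be device memory
// zeroed before the launch:
//
//   int *d_abnormal = nullptr;
//   cudaMalloc(&d_abnormal, 2 * sizeof(int));
//   cudaMemsetAsync(d_abnormal, 0, 2 * sizeof(int), device.stream());
//   CheckNumericsLaunch<float>().Run(device, device_data, num_elements,
//                                    d_abnormal);
//   int host_flags[2];
//   cudaMemcpyAsync(host_flags, d_abnormal, sizeof(host_flags),
//                   cudaMemcpyDeviceToHost, device.stream());
//   cudaStreamSynchronize(device.stream());
//   // host_flags[0] != 0  => a NaN was found.
//   // host_flags[1] != 0  => an Inf was found.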

}  // namespace tensorflow
#endif  // GOOGLE_CUDA