aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/random_op_gpu.cu.cc
blob: 3393b39faf4a25791b48af99a5e474f3e9bfbfce (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA

#define EIGEN_USE_GPU

#include "tensorflow/core/kernels/random_op.h"

#include <assert.h>
#include <stdio.h>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/lib/random/philox_random.h"
#include "tensorflow/core/lib/random/random_distributions.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"

namespace tensorflow {

// Forward declaration so OpKernelContext* can appear in the
// FillPhiloxRandom<GPUDevice, ...>::operator() signature in this file without
// pulling in the full framework header.
class OpKernelContext;

namespace functor {

using GPUDevice = Eigen::GpuDevice;

// Device-side body of the fill kernel. It is specialized in this file on
// whether each output consumes a fixed (false) or a variable (true) number of
// underlying random samples.
template <class Distribution, bool VariableSamplesPerOutput>
struct FillPhiloxRandomKernel;

// Generic fallback copier: writes the ElementCount samples held in `array` to
// consecutive elements of `buf`, one scalar store at a time. Unlike the
// vectorized specializations, it places no alignment requirement on `buf`.
template <typename T, int ElementCount>
class SampleCopier {
 public:
  inline __device__ void operator()(
      T* buf, const tensorflow::random::Array<T, ElementCount>& array) const {
    T* out = buf;
#pragma unroll
    for (int k = 0; k < ElementCount; ++k) {
      *out++ = array[k];
    }
  }
};

template <>
class SampleCopier<float, 4> {
 public:
  // Stores all four samples into `buf` with one 128-bit (float4) write.
  // Precondition: `buf` is 16-byte aligned. Tensor data satisfies this, and so
  // does any element offset that is a multiple of four floats.
  inline __device__ void operator()(
      float* buf, const tensorflow::random::Array<float, 4>& array) const {
    // NOTE(ringwalt): Reinterpreting &array[0] as a float4 is not safe: the
    // array only has 32-bit alignment, whereas float4 needs 128-bit alignment.
    // Packing the elements into a float4 individually shows no performance
    // loss.
    const float4 packed = {array[0], array[1], array[2], array[3]};
    *reinterpret_cast<float4*>(buf) = packed;
  }
};

template <>
class SampleCopier<int32, 4> {
 public:
  // Stores all four samples into `buf` with one 128-bit (int4) write.
  // Precondition: `buf` is 16-byte aligned. Tensor data satisfies this, and so
  // does any element offset that is a multiple of four int32 values.
  inline __device__ void operator()(
      int32* buf, const tensorflow::random::Array<int32, 4>& array) const {
    // Pack element by element; the source array is not guaranteed to have
    // the 128-bit alignment an int4 load would require.
    const int4 packed = {array[0], array[1], array[2], array[3]};
    *reinterpret_cast<int4*>(buf) = packed;
  }
};

template <>
class SampleCopier<double, 2> {
 public:
  // Stores both samples into `buf` with one 128-bit (double2) write.
  // Precondition: `buf` is 16-byte aligned. Tensor data satisfies this, and so
  // does any element offset that is a multiple of two doubles.
  inline __device__ void operator()(
      double* buf, const tensorflow::random::Array<double, 2>& array) const {
    const double2 packed = {array[0], array[1]};
    *reinterpret_cast<double2*>(buf) = packed;
  }
};

template <>
class SampleCopier<int64, 2> {
 public:
  // Stores both samples into `buf` with one 128-bit (longlong2) write.
  // Precondition: `buf` is 16-byte aligned. Tensor data satisfies this, and so
  // does any element offset that is a multiple of two int64 values.
  inline __device__ void operator()(
      int64* buf, const tensorflow::random::Array<int64, 2>& array) const {
    const longlong2 packed = {array[0], array[1]};
    *reinterpret_cast<longlong2*>(buf) = packed;
  }
};

// A cuda kernel to fill the data with random numbers from the specified
// distribution. Each output takes a fixed number of samples.
//
// Work is split into groups of kGroupSize = Distribution::kResultElementCount
// consecutive outputs. Thread t handles groups t, t + T, t + 2T, ... where T is
// the total number of launched threads (a grid-stride loop). Full groups are
// written with a (possibly vectorized) SampleCopier; a trailing partial group
// is written element by element.
//
// The generator is advanced by Skip(thread_id) up front and by Skip(T - 1)
// after each draw. NOTE(review): this assumes each dist(&gen) call advances
// `gen` by exactly one Philox step, so that group g always uses stream
// position g regardless of launch configuration. TODO confirm this for any
// new distribution instantiated with this kernel.
//
// Args:
//   gen:  generator state shared by all threads (copied per thread).
//   data: output buffer of `size` elements. It must be 128-bit aligned when a
//         vectorized SampleCopier specialization applies to T.
//   size: number of elements to write.
//   dist: distribution to sample from.
template <class Distribution>
struct FillPhiloxRandomKernel<Distribution, false> {
  typedef typename Distribution::ResultElementType T;
  PHILOX_DEVICE_FUNC void Run(random::PhiloxRandom gen, T* data, int64 size,
                              Distribution dist) {
    const int kGroupSize = Distribution::kResultElementCount;

    const int32 thread_id = blockIdx.x * blockDim.x + threadIdx.x;
    const int32 total_thread_count = gridDim.x * blockDim.x;
    // Offsets index into a buffer of `size` (int64) elements. Keep them 64-bit
    // so tensors with 2^31 or more elements do not overflow and wrap negative,
    // which would cause out-of-bounds writes or a non-terminating loop.
    int64 offset = static_cast<int64>(thread_id) * kGroupSize;
    const int64 stride = static_cast<int64>(total_thread_count) * kGroupSize;
    gen.Skip(thread_id);

    const SampleCopier<T, kGroupSize> copier;
    while (offset + kGroupSize <= size) {
      const typename Distribution::ResultType samples = dist(&gen);
      copier(&data[offset], samples);

      offset += stride;
      // One draw was just consumed; skip the draws owned by the other threads.
      gen.Skip(total_thread_count - 1);
    }

    // Nothing left for this thread: avoid drawing a sample that is never used.
    if (offset >= size) {
      return;
    }

    // Trailing partial group: fewer than kGroupSize elements remain.
    typename Distribution::ResultType samples = dist(&gen);
    for (int i = 0; i < kGroupSize; ++i) {
      if (offset >= size) {
        return;
      }
      data[offset] = samples[i];
      ++offset;
    }
  }
};

// A cuda kernel to fill the data with random numbers from the specified
// distribution, for distributions where each output may consume a variable
// number of underlying samples.
//
// Outputs are processed in groups of Distribution::kResultElementCount
// elements. Thread t handles groups t, t + T, t + 2T, ... where T is the total
// number of launched threads. Before sampling a group, a fresh copy of
// `base_gen` is skipped to a window reserved for that group:
// kReservedSamplesPerOutput samples per output, i.e.
// kGroupSize * kReservedSamplesPerOutput / PhiloxRandom::kResultElementCount
// Philox steps per group. Each group's values therefore depend only on its
// index, not on how many samples other groups consumed.
// NOTE(review): a distribution that needs more than kReservedSamplesPerOutput
// samples for one output would read into the next group's window.
template <class Distribution>
struct FillPhiloxRandomKernel<Distribution, true> {
  typedef typename Distribution::ResultElementType T;
  PHILOX_DEVICE_FUNC void Run(const random::PhiloxRandom& base_gen, T* data,
                              int64 size, Distribution dist) {
    using random::PhiloxRandom;
    using random::SingleSampleAdapter;

    const int kReservedSamplesPerOutput = 256;
    const int kGroupSize = Distribution::kResultElementCount;
    // Philox steps reserved for one output group.
    const int kGeneratorSkipPerOutputGroup = kGroupSize *
                                             kReservedSamplesPerOutput /
                                             PhiloxRandom::kResultElementCount;

    const int32 first_group = blockIdx.x * blockDim.x + threadIdx.x;
    const int32 group_stride = gridDim.x * blockDim.x;

    for (int64 group = first_group; group * kGroupSize < size;
         group += group_stride) {
      // Re-derive the generator from the base state for every group so the
      // group's samples start at its reserved window.
      PhiloxRandom gen = base_gen;
      gen.Skip(group * kGeneratorSkipPerOutputGroup);
      SingleSampleAdapter<PhiloxRandom> single_samples(&gen);

      typename Distribution::ResultType samples = dist(&single_samples);

      // Only the last group can be partial; clip it to the buffer end.
      const int64 begin = group * kGroupSize;
      const int64 remaining = size - begin;
      const int count =
          remaining < kGroupSize ? static_cast<int>(remaining) : kGroupSize;
      for (int k = 0; k < count; ++k) {
        data[begin + k] = samples[k];
      }
    }
  }
};

// Kernel entry point. It forwards to the FillPhiloxRandomKernel specialization
// selected by Distribution::kVariableSamplesPerOutput. The kernel is declared
// with __launch_bounds__(1024), so launches must use at most 1024 threads per
// block.
template <class Distribution>
__global__ void __launch_bounds__(1024)
    FillPhiloxRandomKernelLaunch(random::PhiloxRandom base_gen,
                                 typename Distribution::ResultElementType* data,
                                 int64 size, Distribution dist) {
  typedef FillPhiloxRandomKernel<Distribution,
                                 Distribution::kVariableSamplesPerOutput>
      Filler;
  Filler filler;
  filler.Run(base_gen, data, size, dist);
}

// Partial specialization for GPU.
//
// Enqueues FillPhiloxRandomKernelLaunch on `d`'s stream to write `size`
// samples of `dist` into `data`. The grid totals roughly
// getNumCudaMultiProcessors() * maxCudaThreadsPerMultiProcessor() threads,
// rounded down to whole blocks. The kernel loops over the output, so every
// element is written whatever the grid size. The launch is asynchronous and
// nothing here waits for it.
template <class Distribution>
void FillPhiloxRandom<GPUDevice, Distribution>::operator()(
    OpKernelContext*, const GPUDevice& d, random::PhiloxRandom gen,
    typename Distribution::ResultElementType* data, int64 size,
    Distribution dist) {
  // Nothing to fill: do not enqueue a kernel at all.
  if (size <= 0) {
    return;
  }

  // Must not exceed the __launch_bounds__(1024) declared on
  // FillPhiloxRandomKernelLaunch; a launch with more threads per block fails.
  const int32 kLaunchBoundsMaxThreads = 1024;
  const int32 device_max_threads = d.maxCudaThreadsPerBlock();
  const int32 block_size = device_max_threads < kLaunchBoundsMaxThreads
                               ? device_max_threads
                               : kLaunchBoundsMaxThreads;

  const int32 target_threads =
      d.getNumCudaMultiProcessors() * d.maxCudaThreadsPerMultiProcessor();
  // Always launch at least one block so the output is actually written.
  const int32 num_blocks =
      target_threads >= block_size ? target_threads / block_size : 1;

  FillPhiloxRandomKernelLaunch<Distribution>
      <<<num_blocks, block_size, 0, d.stream()>>>(gen, data, size, dist);
}

// Explicit instantiation of the GPU distributions functors.
// clang-format off
// NVCC cannot handle ">>" properly, so closing template brackets are written
// as "> >".
#define TF_INSTANTIATE_FILL_PHILOX_GPU(DIST, GEN, T) \
  template struct FillPhiloxRandom<GPUDevice, random::DIST<GEN, T> >

TF_INSTANTIATE_FILL_PHILOX_GPU(UniformDistribution, random::PhiloxRandom, Eigen::half);
TF_INSTANTIATE_FILL_PHILOX_GPU(UniformDistribution, random::PhiloxRandom, float);
TF_INSTANTIATE_FILL_PHILOX_GPU(UniformDistribution, random::PhiloxRandom, double);
TF_INSTANTIATE_FILL_PHILOX_GPU(UniformDistribution, random::PhiloxRandom, int32);
TF_INSTANTIATE_FILL_PHILOX_GPU(UniformDistribution, random::PhiloxRandom, int64);
TF_INSTANTIATE_FILL_PHILOX_GPU(NormalDistribution, random::PhiloxRandom, Eigen::half);
TF_INSTANTIATE_FILL_PHILOX_GPU(NormalDistribution, random::PhiloxRandom, float);
TF_INSTANTIATE_FILL_PHILOX_GPU(NormalDistribution, random::PhiloxRandom, double);
TF_INSTANTIATE_FILL_PHILOX_GPU(TruncatedNormalDistribution,
                               random::SingleSampleAdapter<random::PhiloxRandom>,
                               Eigen::half);
TF_INSTANTIATE_FILL_PHILOX_GPU(TruncatedNormalDistribution,
                               random::SingleSampleAdapter<random::PhiloxRandom>,
                               float);
TF_INSTANTIATE_FILL_PHILOX_GPU(TruncatedNormalDistribution,
                               random::SingleSampleAdapter<random::PhiloxRandom>,
                               double);

#undef TF_INSTANTIATE_FILL_PHILOX_GPU
// clang-format on

}  // namespace functor
}  // namespace tensorflow

#endif  // GOOGLE_CUDA