aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/dynamic_partition_op_gpu.cu.cc
blob: e7882acc80e3c2383f3a3c208175d16dd8c092ab (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// The algorithm for dynamic partition has the following steps:
// 1. Let N be the size of partitions. We initialize a new vector indices_in
//    with the values 0, 1, 2, ..., N-1.
// 2. We apply cub::DeviceRadixSort::SortPairs to the key-value pairs given
//    by partitions and indices_in. This will result in two new vectors
//    partitions_out and indices_out, with partitions_out sorted.
// 3. The first dimension of outputs[i] is equal to the number of i-values in
//    partitions_out. We determine it in two steps:
//    - apply cub::DeviceReduce::ReduceByKey to count how many times each value
//      appears in partitions_out,
//    - move the results to partition_count. This handles missing values
//      (corresponding to empty parts).
// 4. Because partition_count is on the GPU, we bring it asynchronously to
//    the CPU. Then we can allocate the output tensors.
// 5. Finally, we use indices_out and the gather functor to collect the output.
//    This works, because for each interval of i-values, indices_out points
//    to the slices which should form output[i].

#if GOOGLE_CUDA

#define EIGEN_USE_GPU

#include "third_party/cub/device/device_radix_sort.cuh"
#include "third_party/cub/device/device_reduce.cuh"
#include "third_party/cub/iterator/constant_input_iterator.cuh"
#include "third_party/cub/thread/thread_operators.cuh"
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/kernels/gather_functor_gpu.cu.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"
#include "tensorflow/core/util/transform_output_iterator.h"

namespace tensorflow {

typedef Eigen::GpuDevice GPUDevice;

namespace {

// Device half of RangeInit: writes start + k * delta into out[k] for every
// index k in [0, size), iterating with the CUDA_1D_KERNEL_LOOP helper.
template <typename T>
__global__ void RangeInitKernel(const T start, const T delta, const int32 size,
                                T* out) {
  CUDA_1D_KERNEL_LOOP(k, size) {
    const auto step = k * delta;
    out[k] = start + step;
  }
}

// Scatters (key, value) pairs into `out`: for each of the first
// min(*size, out_size) pairs, out[key] = value. `size` points to device
// memory, so it is read inside the kernel. Pairs whose key fails
// FastBoundsCheck against out_size are skipped.
__global__ void MoveValuesKernel(const int32* keys, const int32* values,
                                 const int32* size, int32 out_size,
                                 int32* out) {
  const int32 num_pairs = min(ldg(size), out_size);
  CUDA_1D_KERNEL_LOOP(pos, num_pairs) {
    const int32 dest = ldg(keys + pos);
    const int32 payload = ldg(values + pos);
    const bool writable = FastBoundsCheck(dest, out_size);
    if (writable) out[dest] = payload;
  }
}

// Host-side launcher that fills the first `size` elements of `out` with
// start, start + delta, start + 2 * delta, ... on the stream of `d`.
// This exists because tf.range has no GPU implementation.
template <typename T>
void RangeInit(const GPUDevice& d, const T start, const T delta,
               const int32 size, typename TTypes<T>::Flat out) {
  const CudaLaunchConfig cfg = GetCudaLaunchConfig(size, d);
  T* const dst = out.data();
  RangeInitKernel<T><<<cfg.block_count, cfg.thread_per_block, 0, d.stream()>>>(
      start, delta, size, dst);
}

// Host-side launcher for MoveValuesKernel: given *num_runs (key, value) pairs
// in device memory, stores the value belonging to key i at out[i].
void MoveValues(const GPUDevice& d, int32* keys, int32* values, int32* num_runs,
                int32 out_size, int32* out) {
  // *num_runs resides on the device and cannot be read from the host here,
  // so the grid is sized for out_size instead. Well-formed inputs guarantee
  // *num_runs <= out_size; for malformed inputs (*num_runs > out_size) only
  // the first out_size pairs are processed.
  const CudaLaunchConfig cfg = GetCudaLaunchConfig(out_size, d);
  MoveValuesKernel<<<cfg.block_count, cfg.thread_per_block, 0, d.stream()>>>(
      keys, values, num_runs, out_size, out);
}

// Launches GatherOpKernel on the stream of `d`, sizing the grid by out_size.
// All arguments are forwarded unchanged. The third template argument `true`
// is passed as-is; presumably it selects the axis-0 gather path (confirm in
// gather_functor_gpu.cu.h).
template <typename T>
void CallGatherKernel(const GPUDevice& d, const T* params, const int32* indices,
                      T* out, int64 gather_dim_size, int64 indices_size,
                      int64 slice_size, int64 out_size) {
  const CudaLaunchConfig cfg = GetCudaLaunchConfig(out_size, d);
  const auto blocks = cfg.block_count;
  const auto threads = cfg.thread_per_block;
  GatherOpKernel<T, int32, true><<<blocks, threads, 0, d.stream()>>>(
      params, indices, out, gather_dim_size, indices_size, slice_size,
      out_size);
}

// Pass-through conversion functor used as the transform of
// BoundedOutputIterator: returns its argument unchanged.
struct IdentityOp {
  __device__ __forceinline__ int32 operator()(const int32& value) const {
    return value;
  }
};

// Output iterator over int32 storage that ignores any write landing outside
// [base, base + limit). Dereference and subscript return a proxy whose
// assignment performs the range check; + and - keep the same base and limit,
// so derived iterators stay bounded by the original range.
class BoundedOutputIterator
    : public TransformOutputIterator<int32, int32, IdentityOp> {
 private:
  int32 limit;
  int32* base;

  // Proxy reference that carries the valid range alongside the target
  // pointer inherited from Reference.
  struct BoundedReference : Reference {
    int32 limit;
    int32* base;

    __host__ __device__ __forceinline__
    BoundedReference(int32* ptr, int32* base, IdentityOp op, int32 limit)
        : Reference(ptr, op), limit(limit), base(base) {}

    // Writes val through ptr only when ptr lies in [base, base + limit).
    // Returns val in either case.
    __host__ __device__ __forceinline__ int32 operator=(int32 val) {
      const auto offset = ptr - base;
      const bool in_range = offset >= 0 && offset < limit;
      if (in_range) *ptr = val;
      return val;
    }
  };

 public:
  typedef BoundedOutputIterator self_type;
  typedef BoundedReference reference;

  // Iterator whose valid range starts at ptr itself and spans `size` slots.
  __host__ __device__ __forceinline__ BoundedOutputIterator(int32* ptr,
                                                            IdentityOp op,
                                                            int32 size)
      : TransformOutputIterator(ptr, op), limit(size), base(ptr) {}

  // Iterator positioned at ptr but bounded by [base, base + size).
  __host__ __device__ __forceinline__
  BoundedOutputIterator(int32* ptr, int32* base, IdentityOp op, int32 size)
      : TransformOutputIterator(ptr, op), limit(size), base(base) {}

  __host__ __device__ __forceinline__ reference operator*() const {
    return BoundedReference(ptr, base, conversion_op, limit);
  }

  __host__ __device__ __forceinline__ reference operator[](int32 n) const {
    return BoundedReference(ptr + n, base, conversion_op, limit);
  }

  __host__ __device__ __forceinline__ self_type operator+(int32 n) const {
    return self_type(ptr + n, base, conversion_op, limit);
  }

  __host__ __device__ __forceinline__ self_type operator-(int32 n) const {
    return self_type(ptr - n, base, conversion_op, limit);
  }
};

}  // namespace

// The current implementation has memory cost on GPU
// I + P + max(3N + R + P, O + N), where:
// I - the size of the input
// N - the size of the partitions tensor
// R - the temporary storage used by cub::RadixSort, about 2N
// P - the number of partitions
// O - the size of the output
// So roughly the cost is I + P + max(5N, O + N).
//
// This kernel does not report partition values outside [0, num_partitions)
// as errors. Such values are sorted behind all valid values, so their slices
// are not copied into any output.
template <typename T>
class DynamicPartitionOpGPU : public AsyncOpKernel {
 public:
  explicit DynamicPartitionOpGPU(OpKernelConstruction* c) : AsyncOpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("num_partitions", &num_partitions_));
    OP_REQUIRES(c, num_partitions_ >= 1,
                errors::InvalidArgument("num_partitions must be at least 1"));
  }

  // Allocates the int32 scratch tensors used by CountAndSortParts.
  // indices_in gets max(N, num_partitions_) elements: N for the sort values,
  // and at least num_partitions_ because CountAndSortParts later reuses it
  // for the unique keys produced by ReduceByKey.
  // On failure the error is recorded on `c` and `done` is invoked.
  void AllocateTempSpace(OpKernelContext* c, int32 N, Tensor* indices_in,
                         Tensor* partitions_out, Tensor* indices_out,
                         DoneCallback done) {
    int32 M = std::max(N, num_partitions_);
    // indices_in will be made slightly larger to accommodate
    // later computations.
    OP_REQUIRES_OK_ASYNC(
        c, c->allocate_temp(DT_INT32, TensorShape({M}), indices_in), done);
    OP_REQUIRES_OK_ASYNC(
        c, c->allocate_temp(DT_INT32, TensorShape({N}), partitions_out), done);
    OP_REQUIRES_OK_ASYNC(
        c, c->allocate_temp(DT_INT32, TensorShape({N}), indices_out), done);
  }

  // Allocates output p, for every partition p, with shape
  // [partition_count[p]] + data.shape[partitions.dims():].
  // partition_count is read element-wise on the host, so it must be
  // host-accessible memory.
  // On failure the error is recorded on `c` and `done` is invoked.
  void AllocateOutputs(OpKernelContext* c, const Tensor* data,
                       const Tensor* partitions, const Tensor* partition_count,
                       OpOutputList* Tout, DoneCallback done) {
    auto e_part_count = partition_count->flat<int32>();
    // Allocate output tensors of the right size
    OP_REQUIRES_OK_ASYNC(c, c->output_list("outputs", Tout), done);
    for (int p = 0; p < num_partitions_; p++) {
      TensorShape shape;
      shape.AddDim(e_part_count(p));
      for (int i = partitions->dims(); i < data->dims(); i++) {
        shape.AddDim(data->dim_size(i));
      }
      Tensor* out;
      OP_REQUIRES_OK_ASYNC(c, Tout->allocate(p, shape, &out), done);
    }
  }

  void ComputeAsync(OpKernelContext* c, DoneCallback done) {
    const Tensor& data = c->input(0);
    const Tensor& partitions = c->input(1);

    OP_REQUIRES_ASYNC(
        c, TensorShapeUtils::StartsWith(data.shape(), partitions.shape()),
        errors::InvalidArgument(
            "data.shape must start with partitions.shape, ",
            "got data.shape = ", data.shape().DebugString(),
            ", partitions.shape = ", partitions.shape().DebugString()),
        done);

    // Every stage below (sort length, launch configs, gather indices) counts
    // elements with int32, so reject inputs that would be silently truncated.
    OP_REQUIRES_ASYNC(
        c, partitions.NumElements() <= kint32max,
        errors::InvalidArgument(
            "partitions has too many elements for the GPU kernel: ",
            partitions.NumElements(), " > ", kint32max),
        done);

    Tensor partition_count;

    // We must handle the case of empty partitions separately,
    // because kernels don't work with 0-sized tensors.
    if (partitions.NumElements() == 0) {
      AllocatorAttributes alloc_attr;
      alloc_attr.set_on_host(true);
      OP_REQUIRES_OK_ASYNC(
          c,
          c->allocate_temp(DT_INT32, TensorShape({num_partitions_}),
                           &partition_count, alloc_attr),
          done);
      auto e_part_count = partition_count.flat<int32>();
      for (int i = 0; i < num_partitions_; i++) e_part_count(i) = 0;
      OpOutputList outputs;
      this->AllocateOutputs(c, &data, &partitions, &partition_count, &outputs,
                            done);
      if (c->status().ok()) done();
      return;
    }

    // Prepare for counting.
    OP_REQUIRES_OK_ASYNC(
        c,
        c->allocate_temp(DT_INT32, TensorShape({num_partitions_}),
                         &partition_count),
        done);
    Tensor indices_out;
    // Count how many times each partition index occurs.
    // Also sort the info in partitions and output it in indices_out,
    // in preparation for the next step.
    this->CountAndSortParts(c, &partitions, &partition_count, &indices_out,
                            done);
    if (!c->status().ok()) return;

    // In order to allocate the output tensor we have to move partition_count
    // to CPU.
    auto* stream = c->op_device_context()->stream();
    OP_REQUIRES_ASYNC(c, stream, errors::Internal("No GPU stream available."),
                      done);
    Tensor cpu_tensor;
    AllocatorAttributes alloc_attr;
    alloc_attr.set_on_host(true);
    alloc_attr.set_gpu_compatible(true);
    OP_REQUIRES_OK_ASYNC(
        c,
        c->allocate_temp(partition_count.dtype(), partition_count.shape(),
                         &cpu_tensor, alloc_attr),
        done);
    se::DeviceMemoryBase wrapped(partition_count.flat<int32>().data(),
                                 num_partitions_ * sizeof(int32));
    const bool status =
        stream
            ->ThenMemcpy(cpu_tensor.flat<int32>().data(), wrapped,
                         num_partitions_ * sizeof(int32))
            .ok();
    OP_REQUIRES_ASYNC(
        c, status,
        errors::Internal("Failed to launch copy from device to host."), done);

    // Keep a reference to partition_count so that the buffer
    // is not deallocated at the end of the function, before
    // memcpy is completed.
    TensorReference partition_ref(partition_count);
    // Scheduled through the event manager's ThenExecute, so it runs once the
    // work enqueued on `stream` so far (including the copy above) is done.
    // At that point cpu_tensor holds the per-partition counts needed to size
    // the outputs.
    auto wrapped_callback = [this, c, &data, &partitions, indices_out,
                             partition_ref, cpu_tensor, done]() {
      OpOutputList outputs;
      this->AllocateOutputs(c, &data, &partitions, &cpu_tensor, &outputs, done);
      if (!c->status().ok()) {
        partition_ref.Unref();
        return;
      }
      int32 N = partitions.NumElements();
      int64 slice_size = data.NumElements() / N;
      this->GatherSlices(c, &data, &indices_out, N, slice_size, outputs);
      partition_ref.Unref();
      done();
    };

    c->device()->tensorflow_gpu_device_info()->event_mgr->ThenExecute(
        stream, wrapped_callback);
  }

 protected:
  // Sorts the pairs (partitions[i], i) by partition value on the op's stream.
  // partitions_out receives the sorted keys and indices_out the matching
  // original positions. CUB's radix sort is stable, so positions within one
  // partition stay in ascending order.
  // On failure the error is recorded on `c` and `done` is invoked.
  void RadixSort(OpKernelContext* c, const Tensor* partitions,
                 Tensor* indices_in, Tensor* partitions_out,
                 Tensor* indices_out, DoneCallback done) {
    int32 N = partitions->NumElements();
    const GPUDevice& device = c->eigen_device<GPUDevice>();
    const cudaStream_t& cu_stream = GetCudaStream(c);

    // Initialize the indices_in tensor using the Range GPU kernel.
    RangeInit(device, 0, 1, N, indices_in->flat<int32>());
    // Obtain the pointers to inner buffers.
    // The keys are sorted as *unsigned* integers. Valid partition values
    // [0, num_partitions_) are the smallest keys under both orderings. A
    // signed sort, however, would place negative (invalid) values in front
    // of them, and GatherSlices consumes indices_out from the beginning.
    // Reinterpreted as uint32, negative values are >= 2^31 and therefore
    // land behind every valid key, together with the too-large values.
    const uint32* partitions_ptr =
        reinterpret_cast<const uint32*>(partitions->flat<int32>().data());
    uint32* partitions_out_ptr =
        reinterpret_cast<uint32*>(partitions_out->flat<int32>().data());
    int32* indices_in_ptr = indices_in->flat<int32>().data();
    int32* indices_out_ptr = indices_out->flat<int32>().data();
    // Determine temporary device storage requirements.
    Tensor cub_temp_storage;
    size_t temp_storage_bytes = 0;
    cudaError_t err = cub::DeviceRadixSort::SortPairs(
        NULL, temp_storage_bytes, partitions_ptr, partitions_out_ptr,
        indices_in_ptr, indices_out_ptr, N, 0, sizeof(uint32) * 8, cu_stream);
    OP_REQUIRES_ASYNC(
        c, err == cudaSuccess,
        errors::Internal("Failed to query temporary storage size for "
                         "cub::DeviceRadixSort::SortPairs: ",
                         cudaGetErrorString(err)),
        done);
    // Allocate temporary storage.
    OP_REQUIRES_OK_ASYNC(
        c,
        c->allocate_temp(DT_INT8,
                         TensorShape({static_cast<int64>(temp_storage_bytes)}),
                         &cub_temp_storage),
        done);
    // Radix-sort the partition information.
    err = cub::DeviceRadixSort::SortPairs(
        cub_temp_storage.flat<int8>().data(), temp_storage_bytes,
        partitions_ptr, partitions_out_ptr, indices_in_ptr, indices_out_ptr, N,
        0, sizeof(uint32) * 8, cu_stream);
    OP_REQUIRES_ASYNC(
        c, err == cudaSuccess,
        errors::Internal("Failed to launch cub::DeviceRadixSort::SortPairs: ",
                         cudaGetErrorString(err)),
        done);
  }  // At this point cub_temp_storage will be marked for deallocation.

  // Sorts the partition information (the sorted order of original positions
  // is written to indices_out). It also fills the device tensor
  // partition_count with the number of occurrences of each partition value
  // in [0, num_partitions_). Values outside that range are not counted.
  // On failure the error is recorded on `c` and `done` is invoked.
  void CountAndSortParts(OpKernelContext* c, const Tensor* partitions,
                         Tensor* partition_count, Tensor* indices_out,
                         DoneCallback done) {
    const GPUDevice& device = c->eigen_device<GPUDevice>();
    const cudaStream_t& cu_stream = GetCudaStream(c);
    int32 N = partitions->NumElements();
    Tensor indices_in;
    Tensor partitions_out;
    Tensor aggregates_out;

    // Allocate memory for Radix-Sort.
    this->AllocateTempSpace(c, N, &indices_in, &partitions_out, indices_out,
                            done);
    if (!c->status().ok()) return;
    this->RadixSort(c, partitions, &indices_in, &partitions_out, indices_out,
                    done);
    if (!c->status().ok()) return;
    // We will now apply a reduce operation to count how many times
    // each index appears in partitions.

    // Zero-out the partition_count tensor.
    functor::SetZeroFunctor<GPUDevice, int32> zero_functor;
    zero_functor(device, partition_count->flat<int32>());
    // Allocate memory for aggregates_out.
    OP_REQUIRES_OK_ASYNC(
        c,
        c->allocate_temp(DT_INT32, TensorShape({num_partitions_}),
                         &aggregates_out),
        done);
    // Obtain the pointers to inner buffers. The sorted keys are read back as
    // int32; ReduceByKey only tests keys for equality, which does not depend
    // on the signedness used for sorting.
    int32* keys_in_ptr = partitions_out.flat<int32>().data();
    // Here we reuse the indices_in tensor for the unique keys output.
    int32* unique_out_ptr = indices_in.flat<int32>().data();
    int32* aggregates_out_ptr = aggregates_out.flat<int32>().data();
    // We wrap the pointers in bounded output iterators to guard against
    // wrong inputs (more than num_partitions distinct indices). Because
    // valid keys are sorted first, their runs occupy the first slots and
    // are never dropped by the bound.
    IdentityOp id_op;
    BoundedOutputIterator unique_out_it(unique_out_ptr, id_op, num_partitions_);
    BoundedOutputIterator aggregates_out_it(aggregates_out_ptr, id_op,
                                            num_partitions_);

    cub::ConstantInputIterator<int32> values_in(1);
    cub::Sum reduction_op;

    // Allocate space on GPU for the number of runs. This is required by CUB.
    Tensor num_runs;
    OP_REQUIRES_OK_ASYNC(
        c, c->allocate_temp(DT_INT32, TensorShape({1}), &num_runs), done);
    int32* num_runs_ptr = num_runs.flat<int32>().data();

    // Determine temporary device storage requirements
    Tensor cub_temp_storage;
    size_t temp_storage_bytes = 0;
    cudaError_t err = cub::DeviceReduce::ReduceByKey(
        NULL, temp_storage_bytes, keys_in_ptr, unique_out_it, values_in,
        aggregates_out_it, num_runs_ptr, reduction_op, N, cu_stream);
    OP_REQUIRES_ASYNC(
        c, err == cudaSuccess,
        errors::Internal("Failed to query temporary storage size for "
                         "cub::DeviceReduce::ReduceByKey: ",
                         cudaGetErrorString(err)),
        done);
    // Allocate temporary storage.
    OP_REQUIRES_OK_ASYNC(
        c,
        c->allocate_temp(DT_INT8,
                         TensorShape({static_cast<int64>(temp_storage_bytes)}),
                         &cub_temp_storage),
        done);
    // Run reduce-by-key. The effect is that we count how many times
    // each index appears in partitions. The distinct indices are stored
    // in unique_out, while the count is stored in aggregates_out.
    // The total number of distinct indices is stored in num_runs.
    err = cub::DeviceReduce::ReduceByKey(
        cub_temp_storage.flat<int8>().data(), temp_storage_bytes, keys_in_ptr,
        unique_out_it, values_in, aggregates_out_it, num_runs_ptr,
        reduction_op, N, cu_stream);
    OP_REQUIRES_ASYNC(
        c, err == cudaSuccess,
        errors::Internal("Failed to launch cub::DeviceReduce::ReduceByKey: ",
                         cudaGetErrorString(err)),
        done);
    // We are not done yet. unique_out only contains the indices that appeared
    // at least once in partitions. We move each value from aggregates_out
    // to the corresponding position in partition_count. This will handle
    // possibly empty parts.
    MoveValues(device, unique_out_ptr, aggregates_out_ptr, num_runs_ptr,
               num_partitions_, partition_count->flat<int32>().data());
  }  // At this point indices_in, partitions_out, aggregates_out
     // and cub_temp_storage will be marked for deallocation.

  // Fills each (already allocated) output. Output p receives the rows of
  // `data` whose positions are listed in `indices` at
  // [offset_p, offset_p + outs[p]->dim_size(0)), where offset_p is the sum of
  // the first dimensions of outputs 0..p-1. Empty outputs are skipped.
  void GatherSlices(OpKernelContext* c, const Tensor* data,
                    const Tensor* indices, int32 N, int64 slice_size,
                    OpOutputList& outs) {
    const GPUDevice& device = c->eigen_device<GPUDevice>();
    const int32* ind_base = indices->flat<int32>().data();
    const T* data_base = data->flat<T>().data();

    for (int p = 0; p < num_partitions_; p++) {
      int32 indices_size = outs[p]->dim_size(0);
      int64 out_size = outs[p]->NumElements();
      T* out_base = outs[p]->flat<T>().data();
      if (out_size > 0)
        CallGatherKernel<T>(device, data_base, ind_base, out_base, N,
                            indices_size, slice_size, out_size);
      ind_base += indices_size;
    }
  }

  int32 num_partitions_;
};

// Registers DynamicPartitionOpGPU<TYPE> as the GPU kernel of DynamicPartition
// for every type in TF_CALL_GPU_NUMBER_TYPES, plus complex64 and complex128.
#define REGISTER_GPU_DYNAMIC_PARTITION_KERNEL(TYPE)                 \
  REGISTER_KERNEL_BUILDER(Name("DynamicPartition")                  \
                              .Device(DEVICE_GPU)                   \
                              .TypeConstraint<TYPE>("T"),           \
                          DynamicPartitionOpGPU<TYPE>)

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_DYNAMIC_PARTITION_KERNEL);
TF_CALL_complex64(REGISTER_GPU_DYNAMIC_PARTITION_KERNEL);
TF_CALL_complex128(REGISTER_GPU_DYNAMIC_PARTITION_KERNEL);
#undef REGISTER_GPU_DYNAMIC_PARTITION_KERNEL

}  // namespace tensorflow

#endif  // GOOGLE_CUDA