aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/crop_and_resize_op_gpu.cu.cc
blob: 8ab08fb93aeef2651f2911047d91216c85392705 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/image_ops.cc.

#if GOOGLE_CUDA

#define EIGEN_USE_GPU

#include "tensorflow/core/kernels/crop_and_resize_op.h"

#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"

namespace tensorflow {

// Short alias for the Eigen GPU device type; it is used below as the device
// tag of the functor specializations and as the type of the device handle
// passed to the launch helpers.
typedef Eigen::GpuDevice GPUDevice;

namespace {

// Sampling mode handed to the kernels as a plain int (`method_id`).
// The kernels compare against BILINEAR and treat every other value as
// nearest-neighbour sampling.
enum InterpolationMethod {
  BILINEAR = 0,  // Blend the four pixels surrounding the sampling position.
  NEAREST,       // == 1: pick the pixel closest to the sampling position.
};

// Forward pass of CropAndResize.
//
// Every iteration of CUDA_1D_KERNEL_LOOP produces one element of `crops_ptr`,
// addressed by the flat index
//   out_idx = d + depth * (x + crop_width * (y + crop_height * b)).
//
// Inputs:
//   image_ptr    source images, indexed as
//                ((b_in * image_height + y) * image_width + x) * depth + d.
//   boxes_ptr    four consecutive floats [y1, x1, y2, x2] per box; the values
//                are multiplied by (image_height - 1) / (image_width - 1) to
//                obtain pixel positions.
//   box_ind_ptr  for each box, the index b_in of the image it samples from.
//   method_id    BILINEAR selects bilinear interpolation; any other value is
//                handled as nearest-neighbour sampling.
//   num_boxes    not read by the kernel.
//
// Output elements whose sampling position falls outside
// [0, image_height - 1] x [0, image_width - 1], or evaluates to NaN, are set
// to `extrapolation_value`. Elements whose box index lies outside [0, batch)
// are not written at all.
template <typename T>
__global__ void CropAndResizeKernel(
    const int32 nthreads, const T* image_ptr, const float* boxes_ptr,
    const int32* box_ind_ptr, int num_boxes, int batch, int image_height,
    int image_width, int crop_height, int crop_width, int depth, int method_id,
    float extrapolation_value, float* crops_ptr) {
  CUDA_1D_KERNEL_LOOP(out_idx, nthreads) {
    // Decompose out_idx = d + depth * (x + crop_width * (y + crop_height * b)).
    int idx = out_idx;
    const int d = idx % depth;
    idx /= depth;
    const int x = idx % crop_width;
    idx /= crop_width;
    const int y = idx % crop_height;
    const int b = idx / crop_height;

    const float y1 = boxes_ptr[b * 4];
    const float x1 = boxes_ptr[b * 4 + 1];
    const float y2 = boxes_ptr[b * 4 + 2];
    const float x2 = boxes_ptr[b * 4 + 3];

    // Skip boxes that reference a non-existent image.
    const int32 b_in = box_ind_ptr[b];
    if (b_in < 0 || b_in >= batch) {
      continue;
    }

    // Distance in image pixels between two neighbouring crop samples.
    const float height_scale =
        (crop_height > 1) ? (y2 - y1) * (image_height - 1) / (crop_height - 1)
                          : 0.0f;
    const float width_scale =
        (crop_width > 1) ? (x2 - x1) * (image_width - 1) / (crop_width - 1)
                         : 0.0f;

    // A crop that is a single pixel tall samples the vertical centre of the
    // box. The float literal keeps the arithmetic in single precision.
    const float in_y = (crop_height > 1)
                           ? y1 * (image_height - 1) + y * height_scale
                           : 0.5f * (y1 + y2) * (image_height - 1);
    // Phrased as a negated in-range test so that a NaN coordinate (for which
    // every ordered comparison is false) is rejected as well, instead of
    // being converted to an integer pixel index below.
    if (!(in_y >= 0 && in_y <= image_height - 1)) {
      crops_ptr[out_idx] = extrapolation_value;
      continue;
    }

    // Same treatment for the horizontal coordinate.
    const float in_x = (crop_width > 1)
                           ? x1 * (image_width - 1) + x * width_scale
                           : 0.5f * (x1 + x2) * (image_width - 1);
    if (!(in_x >= 0 && in_x <= image_width - 1)) {
      crops_ptr[out_idx] = extrapolation_value;
      continue;
    }

    if (method_id == BILINEAR) {
      const int top_y_index = floorf(in_y);
      const int bottom_y_index = ceilf(in_y);
      const float y_lerp = in_y - top_y_index;

      const int left_x_index = floorf(in_x);
      const int right_x_index = ceilf(in_x);
      const float x_lerp = in_x - left_x_index;

      // Flat index of pixel 0 of the top / bottom source rows (before the
      // multiplication by depth).
      const int top_row = (b_in * image_height + top_y_index) * image_width;
      const int bottom_row =
          (b_in * image_height + bottom_y_index) * image_width;

      const float top_left =
          static_cast<float>(image_ptr[(top_row + left_x_index) * depth + d]);
      const float top_right =
          static_cast<float>(image_ptr[(top_row + right_x_index) * depth + d]);
      const float bottom_left = static_cast<float>(
          image_ptr[(bottom_row + left_x_index) * depth + d]);
      const float bottom_right = static_cast<float>(
          image_ptr[(bottom_row + right_x_index) * depth + d]);

      // Interpolate horizontally along both rows, then vertically.
      const float top = top_left + (top_right - top_left) * x_lerp;
      const float bottom = bottom_left + (bottom_right - bottom_left) * x_lerp;
      crops_ptr[out_idx] = top + (bottom - top) * y_lerp;
    } else {  // Any other method_id is handled as NEAREST.
      const int closest_x_index = roundf(in_x);
      const int closest_y_index = roundf(in_y);
      crops_ptr[out_idx] = static_cast<float>(
          image_ptr[((b_in * image_height + closest_y_index) * image_width +
                     closest_x_index) *
                        depth +
                    d]);
    }
  }
}

// Backward pass of CropAndResize with respect to the image.
//
// Every iteration of CUDA_1D_KERNEL_LOOP takes one incoming gradient element
//   grads_ptr[out_idx],
//   out_idx = d + depth * (x + crop_width * (y + crop_height * b)),
// recomputes the position (in_y, in_x) that the forward pass sampled for it
// and adds the gradient to `grads_image_ptr` with CudaAtomicAdd, using the
// same weights the forward pass used: the four bilinear weights when
// `method_id` is BILINEAR, otherwise the full gradient on the nearest pixel.
//
// `grads_image_ptr` is indexed as
//   ((b_in * image_height + y) * image_width + x) * depth + d
// and is only ever accumulated into, so it must hold the desired initial
// values before the kernel runs.
//
// Elements are skipped (contribute nothing) when the box index lies outside
// [0, batch) or when the sampling position is outside
// [0, image_height - 1] x [0, image_width - 1] or evaluates to NaN.
// `num_boxes` is not read by the kernel.
template <typename T>
__global__ void CropAndResizeBackpropImageKernel(
    const int32 nthreads, const float* grads_ptr, const float* boxes_ptr,
    const int32* box_ind_ptr, int num_boxes, int batch, int image_height,
    int image_width, int crop_height, int crop_width, int depth,
    T* grads_image_ptr, int method_id) {
  CUDA_1D_KERNEL_LOOP(out_idx, nthreads) {
    // Decompose out_idx = d + depth * (x + crop_width * (y + crop_height * b)).
    int idx = out_idx;
    const int d = idx % depth;
    idx /= depth;
    const int x = idx % crop_width;
    idx /= crop_width;
    const int y = idx % crop_height;
    const int b = idx / crop_height;

    const float y1 = boxes_ptr[b * 4];
    const float x1 = boxes_ptr[b * 4 + 1];
    const float y2 = boxes_ptr[b * 4 + 2];
    const float x2 = boxes_ptr[b * 4 + 3];

    // Skip boxes that reference a non-existent image.
    const int32 b_in = box_ind_ptr[b];
    if (b_in < 0 || b_in >= batch) {
      continue;
    }

    // Distance in image pixels between two neighbouring crop samples.
    const float height_scale =
        (crop_height > 1) ? (y2 - y1) * (image_height - 1) / (crop_height - 1)
                          : 0.0f;
    const float width_scale =
        (crop_width > 1) ? (x2 - x1) * (image_width - 1) / (crop_width - 1)
                         : 0.0f;

    // A crop that is a single pixel tall samples the vertical centre of the
    // box. The float literal keeps the arithmetic in single precision.
    const float in_y = (crop_height > 1)
                           ? y1 * (image_height - 1) + y * height_scale
                           : 0.5f * (y1 + y2) * (image_height - 1);
    // Negated in-range test: a NaN coordinate fails it too, so it is skipped
    // instead of being converted to a pixel index and spreading NaN weights
    // into the image gradient.
    if (!(in_y >= 0 && in_y <= image_height - 1)) {
      continue;
    }

    const float in_x = (crop_width > 1)
                           ? x1 * (image_width - 1) + x * width_scale
                           : 0.5f * (x1 + x2) * (image_width - 1);
    if (!(in_x >= 0 && in_x <= image_width - 1)) {
      continue;
    }

    // Incoming gradient for this crop element, loaded once.
    const float top_grad = grads_ptr[out_idx];

    if (method_id == BILINEAR) {
      const int top_y_index = floorf(in_y);
      const int bottom_y_index = ceilf(in_y);
      const float y_lerp = in_y - top_y_index;

      const int left_x_index = floorf(in_x);
      const int right_x_index = ceilf(in_x);
      const float x_lerp = in_x - left_x_index;

      // Flat index of pixel 0 of the top / bottom target rows (before the
      // multiplication by depth).
      const int top_row = (b_in * image_height + top_y_index) * image_width;
      const int bottom_row =
          (b_in * image_height + bottom_y_index) * image_width;

      // Share of the gradient that belongs to the top row, split between its
      // left and right pixels.
      const float dtop = (1 - y_lerp) * top_grad;
      CudaAtomicAdd(grads_image_ptr + (top_row + left_x_index) * depth + d,
                    static_cast<T>((1 - x_lerp) * dtop));
      CudaAtomicAdd(grads_image_ptr + (top_row + right_x_index) * depth + d,
                    static_cast<T>(x_lerp * dtop));

      // Remaining share for the bottom row.
      const float dbottom = y_lerp * top_grad;
      CudaAtomicAdd(grads_image_ptr + (bottom_row + left_x_index) * depth + d,
                    static_cast<T>((1 - x_lerp) * dbottom));
      CudaAtomicAdd(grads_image_ptr + (bottom_row + right_x_index) * depth + d,
                    static_cast<T>(x_lerp * dbottom));
    } else {  // Any other method_id is handled as NEAREST.
      const int closest_x_index = roundf(in_x);
      const int closest_y_index = roundf(in_y);
      CudaAtomicAdd(grads_image_ptr +
                        ((b_in * image_height + closest_y_index) * image_width +
                         closest_x_index) *
                            depth +
                        d,
                    static_cast<T>(top_grad));
    }
  }
}

// Backward pass of CropAndResize with respect to the box coordinates.
//
// Every iteration of CUDA_1D_KERNEL_LOOP takes one incoming gradient element
//   grads_ptr[out_idx],
//   out_idx = d + depth * (x + crop_width * (y + crop_height * b)),
// recomputes the bilinear sampling position (in_y, in_x) for it, evaluates the
// derivative of the bilinearly interpolated image value with respect to in_y
// and in_x, and chains it through
//   in_y = y1 * (image_height - 1) + y * (y2 - y1) * height_ratio
//   in_x = x1 * (image_width  - 1) + x * (x2 - x1) * width_ratio
// (or the box centre when the crop has a single row / column). The four
// resulting values are added with CudaAtomicAdd to
//   grads_boxes_ptr[b * 4 + {0, 1, 2, 3}]  ->  d/dy1, d/dx1, d/dy2, d/dx2.
//
// `grads_boxes_ptr` is only ever accumulated into, so it must hold the desired
// initial values before the kernel runs. This kernel always differentiates the
// bilinear formula; it has no interpolation-method parameter.
//
// Elements are skipped (contribute nothing) when the box index lies outside
// [0, batch) or when the sampling position is outside
// [0, image_height - 1] x [0, image_width - 1] or evaluates to NaN.
// `num_boxes` is not read by the kernel.
template <typename T>
__global__ void CropAndResizeBackpropBoxesKernel(
    const int32 nthreads, const float* grads_ptr, const T* image_ptr,
    const float* boxes_ptr, const int32* box_ind_ptr, int num_boxes, int batch,
    int image_height, int image_width, int crop_height, int crop_width,
    int depth, float* grads_boxes_ptr) {
  CUDA_1D_KERNEL_LOOP(out_idx, nthreads) {
    // Decompose out_idx = d + depth * (x + crop_width * (y + crop_height * b)).
    int idx = out_idx;
    const int d = idx % depth;
    idx /= depth;
    const int x = idx % crop_width;
    idx /= crop_width;
    const int y = idx % crop_height;
    const int b = idx / crop_height;

    const float y1 = boxes_ptr[b * 4];
    const float x1 = boxes_ptr[b * 4 + 1];
    const float y2 = boxes_ptr[b * 4 + 2];
    const float x2 = boxes_ptr[b * 4 + 3];

    // Skip boxes that reference a non-existent image.
    const int32 b_in = box_ind_ptr[b];
    if (b_in < 0 || b_in >= batch) {
      continue;
    }

    // Image pixels per crop sample for a box that spans the whole image.
    const float height_ratio =
        (crop_height > 1)
            ? static_cast<float>(image_height - 1) / (crop_height - 1)
            : 0.0f;
    const float width_ratio =
        (crop_width > 1)
            ? static_cast<float>(image_width - 1) / (crop_width - 1)
            : 0.0f;

    // The same ratios scaled by the extent of this particular box.
    const float height_scale =
        (crop_height > 1) ? (y2 - y1) * height_ratio : 0.0f;
    const float width_scale = (crop_width > 1) ? (x2 - x1) * width_ratio : 0.0f;

    // A crop that is a single pixel tall samples the vertical centre of the
    // box. The float literal keeps the arithmetic in single precision.
    const float in_y = (crop_height > 1)
                           ? y1 * (image_height - 1) + y * height_scale
                           : 0.5f * (y1 + y2) * (image_height - 1);
    // Negated in-range test: a NaN coordinate fails it too, so it is skipped
    // instead of being converted to a pixel index and accumulating NaN into
    // the box gradient.
    if (!(in_y >= 0 && in_y <= image_height - 1)) {
      continue;
    }

    const float in_x = (crop_width > 1)
                           ? x1 * (image_width - 1) + x * width_scale
                           : 0.5f * (x1 + x2) * (image_width - 1);
    if (!(in_x >= 0 && in_x <= image_width - 1)) {
      continue;
    }

    const int top_y_index = floorf(in_y);
    const int bottom_y_index = ceilf(in_y);
    const float y_lerp = in_y - top_y_index;

    const int left_x_index = floorf(in_x);
    const int right_x_index = ceilf(in_x);
    const float x_lerp = in_x - left_x_index;

    // Flat index of pixel 0 of the top / bottom source rows (before the
    // multiplication by depth).
    const int top_row = (b_in * image_height + top_y_index) * image_width;
    const int bottom_row = (b_in * image_height + bottom_y_index) * image_width;

    const float top_left =
        static_cast<float>(image_ptr[(top_row + left_x_index) * depth + d]);
    const float top_right =
        static_cast<float>(image_ptr[(top_row + right_x_index) * depth + d]);
    const float bottom_left =
        static_cast<float>(image_ptr[(bottom_row + left_x_index) * depth + d]);
    const float bottom_right =
        static_cast<float>(image_ptr[(bottom_row + right_x_index) * depth + d]);

    // Derivative of the interpolated value with respect to in_y and in_x.
    float image_grad_y = (1 - x_lerp) * (bottom_left - top_left) +
                         x_lerp * (bottom_right - top_right);
    float image_grad_x = (1 - y_lerp) * (top_right - top_left) +
                         y_lerp * (bottom_right - bottom_left);
    // Modulate the image gradient with the incoming gradient.
    const float top_grad = grads_ptr[out_idx];
    image_grad_y *= top_grad;
    image_grad_x *= top_grad;

    // Chain rule through in_y: d(in_y)/d(y1) and d(in_y)/d(y2).
    float dy1, dy2;
    if (crop_height > 1) {
      dy1 = image_grad_y * (image_height - 1 - y * height_ratio);
      dy2 = image_grad_y * (y * height_ratio);
    } else {
      // Box-centre sampling depends equally on y1 and y2.
      dy1 = image_grad_y * 0.5f * (image_height - 1);
      dy2 = dy1;
    }

    // Chain rule through in_x: d(in_x)/d(x1) and d(in_x)/d(x2).
    float dx1, dx2;
    if (crop_width > 1) {
      dx1 = image_grad_x * (image_width - 1 - x * width_ratio);
      dx2 = image_grad_x * (x * width_ratio);
    } else {
      // Box-centre sampling depends equally on x1 and x2.
      dx1 = image_grad_x * 0.5f * (image_width - 1);
      dx2 = dx1;
    }

    // Accumulate in the [y1, x1, y2, x2] order used by boxes_ptr.
    CudaAtomicAdd(grads_boxes_ptr + b * 4 + 0, dy1);
    CudaAtomicAdd(grads_boxes_ptr + b * 4 + 1, dx1);
    CudaAtomicAdd(grads_boxes_ptr + b * 4 + 2, dy2);
    CudaAtomicAdd(grads_boxes_ptr + b * 4 + 3, dx2);
  }
}

}  // namespace

namespace functor {

// GPU implementation of the forward CropAndResize functor.
//
// Reads the image extents from `image` and the crop extents from `crops`,
// maps `method_name` to an InterpolationMethod ("nearest" -> NEAREST, anything
// else -> BILINEAR) and launches CropAndResizeKernel with one work item per
// element of `crops` on the stream of the context's GPU device. No kernel is
// launched when `crops` has no elements.
//
// Returns the result of GPUDevice::ok() queried after the (optional) launch.
template <typename T>
struct CropAndResize<GPUDevice, T> {
  bool operator()(const OpKernelContext* context,
                  typename TTypes<T, 4>::ConstTensor image,
                  typename TTypes<float, 2>::ConstTensor boxes,
                  typename TTypes<int32, 1>::ConstTensor box_ind,
                  string method_name, float extrapolation_value,
                  typename TTypes<float, 4>::Tensor crops) {
    // Source image extents.
    const int batch = image.dimension(0);
    const int image_height = image.dimension(1);
    const int image_width = image.dimension(2);

    // Output extents; the kernel handles one output element per work item.
    const int num_boxes = crops.dimension(0);
    const int crop_height = crops.dimension(1);
    const int crop_width = crops.dimension(2);
    const int depth = crops.dimension(3);
    const int num_elements = num_boxes * crop_height * crop_width * depth;

    const GPUDevice& device = context->eigen_device<GPUDevice>();

    // Nothing to compute for an empty output.
    if (num_elements <= 0) {
      return device.ok();
    }

    const InterpolationMethod method =
        (method_name == "nearest") ? NEAREST : BILINEAR;

    const CudaLaunchConfig launch = GetCudaLaunchConfig(num_elements, device);
    CropAndResizeKernel<<<launch.block_count, launch.thread_per_block, 0,
                          device.stream()>>>(
        launch.virtual_thread_count, image.data(), boxes.data(),
        box_ind.data(), num_boxes, batch, image_height, image_width,
        crop_height, crop_width, depth, method, extrapolation_value,
        crops.data());
    return device.ok();
  }
};

// GPU implementation of the CropAndResize image-gradient functor.
//
// Works in two steps on the stream of `d`:
//   1. SetZero clears every element of `grads_image`.
//   2. CropAndResizeBackpropImageKernel accumulates the incoming `grads` into
//      `grads_image`, one work item per element of `grads`.
// Either step is skipped when its element count is zero. `method_name` is
// mapped to an InterpolationMethod ("nearest" -> NEAREST, anything else ->
// BILINEAR).
//
// Returns the result of d.ok() queried after the launches.
template <typename T>
struct CropAndResizeBackpropImage<GPUDevice, T> {
  bool operator()(const GPUDevice& d,
                  typename TTypes<float, 4>::ConstTensor grads,
                  typename TTypes<float, 2>::ConstTensor boxes,
                  typename TTypes<int32, 1>::ConstTensor box_ind,
                  typename TTypes<T, 4>::Tensor grads_image,
                  const string& method_name) {
    // Extents of the image gradient being produced.
    const int batch = grads_image.dimension(0);
    const int image_height = grads_image.dimension(1);
    const int image_width = grads_image.dimension(2);

    // Extents of the incoming crop gradient.
    const int num_boxes = grads.dimension(0);
    const int crop_height = grads.dimension(1);
    const int crop_width = grads.dimension(2);
    const int depth = grads.dimension(3);

    // Step 1: start from an all-zero image gradient, because the accumulation
    // kernel only ever adds to it.
    const int image_elements = batch * image_height * image_width * depth;
    if (image_elements > 0) {
      const CudaLaunchConfig zero_cfg = GetCudaLaunchConfig(image_elements, d);
      SetZero<<<zero_cfg.block_count, zero_cfg.thread_per_block, 0,
                d.stream()>>>(zero_cfg.virtual_thread_count,
                              grads_image.data());
    }

    // Select the interpolation method.
    const InterpolationMethod method =
        (method_name == "nearest") ? NEAREST : BILINEAR;

    // Step 2: scatter every crop gradient element back into the image.
    const int crop_elements = num_boxes * crop_height * crop_width * depth;
    if (crop_elements > 0) {
      const CudaLaunchConfig accum_cfg = GetCudaLaunchConfig(crop_elements, d);
      CropAndResizeBackpropImageKernel<<<
          accum_cfg.block_count, accum_cfg.thread_per_block, 0, d.stream()>>>(
          accum_cfg.virtual_thread_count, grads.data(), boxes.data(),
          box_ind.data(), num_boxes, batch, image_height, image_width,
          crop_height, crop_width, depth, grads_image.data(), method);
    }
    return d.ok();
  }
};

// GPU implementation of the CropAndResize box-gradient functor.
//
// Works in two steps on the stream of `d`:
//   1. SetZero clears the num_boxes * 4 elements of `grads_boxes`.
//   2. CropAndResizeBackpropBoxesKernel accumulates the contribution of every
//      element of `grads` into `grads_boxes`, one work item per element.
// Either step is skipped when its element count is zero.
//
// Returns the result of d.ok() queried after the launches.
template <typename T>
struct CropAndResizeBackpropBoxes<GPUDevice, T> {
  bool operator()(const GPUDevice& d,
                  typename TTypes<float, 4>::ConstTensor grads,
                  typename TTypes<T, 4>::ConstTensor image,
                  typename TTypes<float, 2>::ConstTensor boxes,
                  typename TTypes<int32, 1>::ConstTensor box_ind,
                  typename TTypes<float, 2>::Tensor grads_boxes) {
    // Source image extents.
    const int batch = image.dimension(0);
    const int image_height = image.dimension(1);
    const int image_width = image.dimension(2);

    // Extents of the incoming crop gradient.
    const int num_boxes = grads.dimension(0);
    const int crop_height = grads.dimension(1);
    const int crop_width = grads.dimension(2);
    const int depth = grads.dimension(3);

    // Step 1: start from an all-zero box gradient (four values per box),
    // because the accumulation kernel only ever adds to it.
    const int box_elements = num_boxes * 4;
    if (box_elements > 0) {
      const CudaLaunchConfig zero_cfg = GetCudaLaunchConfig(box_elements, d);
      SetZero<<<zero_cfg.block_count, zero_cfg.thread_per_block, 0,
                d.stream()>>>(zero_cfg.virtual_thread_count,
                              grads_boxes.data());
    }

    // Step 2: accumulate the contribution of every crop gradient element.
    const int crop_elements = num_boxes * crop_height * crop_width * depth;
    if (crop_elements > 0) {
      const CudaLaunchConfig accum_cfg = GetCudaLaunchConfig(crop_elements, d);
      CropAndResizeBackpropBoxesKernel<<<
          accum_cfg.block_count, accum_cfg.thread_per_block, 0, d.stream()>>>(
          accum_cfg.virtual_thread_count, grads.data(), image.data(),
          boxes.data(), box_ind.data(), num_boxes, batch, image_height,
          image_width, crop_height, crop_width, depth, grads_boxes.data());
    }
    return d.ok();
  }
};

// Explicitly instantiates, for one element type T, the three GPU functor
// specializations defined above so that their code is emitted in this
// translation unit.
#define DEFINE_GPU_SPECS(T)                                 \
  template struct CropAndResize<GPUDevice, T>;              \
  template struct CropAndResizeBackpropImage<GPUDevice, T>; \
  template struct CropAndResizeBackpropBoxes<GPUDevice, T>;

// Applies DEFINE_GPU_SPECS through TF_CALL_GPU_NUMBER_TYPES (defined outside
// this file; presumably it expands the macro once per GPU-supported numeric
// type -- confirm against register_types.h).
TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPECS);

// The helper macro is only needed for the instantiations above.
#undef DEFINE_GPU_SPECS

// Explicit instantiation of the box-index validation helper for the GPU
// device. Its definition is not in this file.
// NOTE(review): presumably it comes from crop_and_resize_op.h -- confirm.
template struct CheckValidBoxIndexHelper<GPUDevice>;

}  // namespace functor
}  // namespace tensorflow

#endif  // GOOGLE_CUDA