aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/adjust_hue_op_gpu.cu.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/adjust_hue_op_gpu.cu.cc')
-rw-r--r--tensorflow/core/kernels/adjust_hue_op_gpu.cu.cc141
1 files changed, 141 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/adjust_hue_op_gpu.cu.cc b/tensorflow/core/kernels/adjust_hue_op_gpu.cu.cc
new file mode 100644
index 0000000000..2fc69ed101
--- /dev/null
+++ b/tensorflow/core/kernels/adjust_hue_op_gpu.cu.cc
@@ -0,0 +1,141 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/kernels/adjust_hue_op.h"
+#include "tensorflow/core/util/cuda_kernel_helper.h"
+
+namespace tensorflow {
+namespace internal {
+
+namespace {
+ typedef struct RgbTuple {
+ float r;
+ float g;
+ float b;
+ } RgbTuple;
+
+ typedef struct HsvTuple {
+ float h;
+ float s;
+ float v;
+ } HsvTuple;
+} // anon namespace
+
+__device__ HsvTuple rgb2hsv_cuda(const float r, const float g, const float b)
+{
+ HsvTuple tuple;
+ const float M = fmaxf(r, fmaxf(g, b));
+ const float m = fminf(r, fminf(g, b));
+ const float chroma = M - m;
+ float h = 0.0f, s = 0.0f;
+ // hue
+ if (chroma > 0.0f) {
+ if (M == r) {
+ const float num = (g - b) / chroma;
+ const float sign = copysignf(1.0f, num);
+ h = ((sign < 0.0f) * 6.0f + sign * fmodf(sign * num, 6.0f)) / 6.0f;
+ } else if (M == g) {
+ h = ((b - r) / chroma + 2.0f) / 6.0f;
+ } else {
+ h = ((r - g) / chroma + 4.0f) / 6.0f;
+ }
+ } else {
+ h = 0.0f;
+ }
+ // saturation
+ if (M > 0.0) {
+ s = chroma / M;
+ } else {
+ s = 0.0f;
+ }
+ tuple.h = h;
+ tuple.s = s;
+ tuple.v = M;
+ return tuple;
+}
+
+__device__ RgbTuple hsv2rgb_cuda(const float h, const float s, const float v)
+{
+ RgbTuple tuple;
+ const float new_h = h * 6.0f;
+ const float chroma = v * s;
+ const float x = chroma * (1.0f - fabsf(fmodf(new_h, 2.0f) - 1.0f));
+ const float new_m = v - chroma;
+ const bool between_0_and_1 = new_h >= 0.0f && new_h < 1.0f;
+ const bool between_1_and_2 = new_h >= 1.0f && new_h < 2.0f;
+ const bool between_2_and_3 = new_h >= 2.0f && new_h < 3.0f;
+ const bool between_3_and_4 = new_h >= 3.0f && new_h < 4.0f;
+ const bool between_4_and_5 = new_h >= 4.0f && new_h < 5.0f;
+ const bool between_5_and_6 = new_h >= 5.0f && new_h < 6.0f;
+ tuple.r = chroma * (between_0_and_1 || between_5_and_6) +
+ x * (between_1_and_2 || between_4_and_5) + new_m;
+ tuple.g = chroma * (between_1_and_2 || between_2_and_3) +
+ x * (between_0_and_1 || between_3_and_4) + new_m;
+ tuple.b = chroma * (between_3_and_4 || between_4_and_5) +
+ x * (between_2_and_3 || between_5_and_6) + new_m;
+ return tuple;
+}
+
+__global__ void adjust_hue_nhwc(const int64 number_elements,
+ const float * const __restrict__ input,
+ float * const output,
+ const float * const hue_delta)
+{
+ // multiply by 3 since we're dealing with contiguous RGB bytes for each pixel (NHWC)
+ const int64 idx = (blockDim.x * blockIdx.x + threadIdx.x) * 3;
+ // bounds check
+ if (idx > number_elements - 1) {
+ return;
+ }
+ const float delta = hue_delta[0];
+ const HsvTuple hsv = rgb2hsv_cuda(input[idx], input[idx + 1], input[idx + 2]);
+ // hue adjustment
+ float new_h = fmodf(hsv.h + delta, 1.0f);
+ if (new_h < 0.0f) {
+ new_h = fmodf(1.0f + new_h, 1.0f);
+ }
+ const RgbTuple rgb = hsv2rgb_cuda(new_h, hsv.s, hsv.v);
+ output[idx] = rgb.r;
+ output[idx + 1] = rgb.g;
+ output[idx + 2] = rgb.b;
+}
+} // namespace internal
+
+
+namespace functor {
+
+void AdjustHueGPU::operator()(
+ GPUDevice* device,
+ const int64 number_of_elements,
+ const float* const input,
+ const float* const delta,
+ float* const output
+) {
+ const auto stream = device->stream();
+ const CudaLaunchConfig config = GetCudaLaunchConfig(number_of_elements, *device);
+ const int threads_per_block = config.thread_per_block;
+ const int block_count = (number_of_elements + threads_per_block - 1) / threads_per_block;
+ internal::adjust_hue_nhwc<<<block_count, threads_per_block, 0, stream>>>(
+ number_of_elements, input, output, delta
+ );
+}
+} // namespace functor
+} // namespace tensorflow
+#endif // GOOGLE_CUDA