diff options
Diffstat (limited to 'tensorflow/core/kernels/adjust_hue_op.cc')
-rw-r--r-- | tensorflow/core/kernels/adjust_hue_op.cc | 43 |
1 files changed, 41 insertions, 2 deletions
# Patch summary (lines before the first "diff --git" are ignored by patch tools).
#
# What this patch does to tensorflow/core/kernels/adjust_hue_op.cc:
#   * defines EIGEN_USE_THREADS, and EIGEN_USE_GPU when GOOGLE_CUDA is set,
#     ahead of the includes;
#   * moves the Eigen Tensor include to the end of the include list and adds
#     tensor_types.h and adjust_hue_op.h;
#   * adds an AdjustHueOp<GPUDevice> specialization whose DoCompute checks that
#     the Eigen GPU device has a stream and, for non-empty input, passes the
#     flat float buffers to functor::AdjustHueGPU;
#   * registers the "AdjustHue" kernel for DEVICE_GPU.
#
# Changed relative to the scraped text:
#   * one diff line per row restored; blank context lines and indentation are
#     reconstructed from the hunk line counts -- TODO confirm against the
#     repository before applying;
#   * the added DoCompute is declared with `override` only (the redundant
#     `virtual` was dropped); line counts are unchanged.
#
# NOTE(review): the first hunk deletes the blank line after the copyright
#   line -- looks unintentional, confirm with the author.
# NOTE(review): the added `//}  // namespace functor` line is commented-out
#   residue; it is kept here only so the hunk counts and diffstat stay valid.
diff --git a/tensorflow/core/kernels/adjust_hue_op.cc b/tensorflow/core/kernels/adjust_hue_op.cc
index 09300737c7..e8f32693f7 100644
--- a/tensorflow/core/kernels/adjust_hue_op.cc
+++ b/tensorflow/core/kernels/adjust_hue_op.cc
@@ -1,5 +1,4 @@
 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at
@@ -12,16 +11,24 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
+#define EIGEN_USE_THREADS
+
+#if GOOGLE_CUDA
+#define EIGEN_USE_GPU
+#endif
+
 #include <memory>
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_types.h"
 #include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/adjust_hue_op.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/util/work_sharder.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 
 namespace tensorflow {
 
@@ -77,6 +84,7 @@ template <class Device>
 class AdjustHueOp;
 
 namespace internal {
+
 // Helper function to convert a RGB color to H-and-V-range. H is in the range
 // of [0, 6] instead of the normal [0, 1]
 static void rgb_to_hv_range(float r, float g, float b, float* h, float* v_min,
@@ -185,6 +193,7 @@ static void hv_range_to_rgb(float h, float v_min, float v_max, float* r,
 }
 
 }  // namespace internal
+
 template <>
 class AdjustHueOp<CPUDevice> : public AdjustHueOpBase {
  public:
@@ -237,4 +246,34 @@ class AdjustHueOp<CPUDevice> : public AdjustHueOpBase {
 REGISTER_KERNEL_BUILDER(Name("AdjustHue").Device(DEVICE_CPU),
                         AdjustHueOp<CPUDevice>);
 
+#if GOOGLE_CUDA
+template <>
+class AdjustHueOp<GPUDevice> : public AdjustHueOpBase {
+ public:
+  explicit AdjustHueOp(OpKernelConstruction* context)
+      : AdjustHueOpBase(context) {}
+
+  void DoCompute(OpKernelContext* context, const ComputeOptions& options) override {
+    const Tensor* input = options.input;
+    const Tensor* delta = options.delta;
+    Tensor* output = options.output;
+    const int64 number_of_elements = input->NumElements();
+    GPUDevice device = context->eigen_gpu_device();
+    const auto stream = device.stream();
+    OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
+    if (number_of_elements > 0) {
+      const float* input_data = input->flat<float>().data();
+      const float* delta_h = delta->flat<float>().data();
+      float* const output_data = output->flat<float>().data();
+      functor::AdjustHueGPU()(&device, number_of_elements, input_data, delta_h,
+                              output_data);
+    }
+  }
+};
+
+REGISTER_KERNEL_BUILDER(Name("AdjustHue").Device(DEVICE_GPU), AdjustHueOp<GPUDevice>);
+
+#endif
+
+//}  // namespace functor
 }  // namespace tensorflow