aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/adjust_hue_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/adjust_hue_op.cc')
-rw-r--r--tensorflow/core/kernels/adjust_hue_op.cc43
1 files changed, 41 insertions, 2 deletions
diff --git a/tensorflow/core/kernels/adjust_hue_op.cc b/tensorflow/core/kernels/adjust_hue_op.cc
index 09300737c7..e8f32693f7 100644
--- a/tensorflow/core/kernels/adjust_hue_op.cc
+++ b/tensorflow/core/kernels/adjust_hue_op.cc
@@ -1,5 +1,4 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
@@ -12,16 +11,24 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#define EIGEN_USE_THREADS
+
+#if GOOGLE_CUDA
+#define EIGEN_USE_GPU
+#endif
+
#include <memory>
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/adjust_hue_op.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/work_sharder.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
namespace tensorflow {
@@ -77,6 +84,7 @@ template <class Device>
class AdjustHueOp;
namespace internal {
+
// Helper function to convert a RGB color to H-and-V-range. H is in the range
// of [0, 6] instead of the normal [0, 1]
static void rgb_to_hv_range(float r, float g, float b, float* h, float* v_min,
@@ -185,6 +193,7 @@ static void hv_range_to_rgb(float h, float v_min, float v_max, float* r,
}
} // namespace internal
+
template <>
class AdjustHueOp<CPUDevice> : public AdjustHueOpBase {
public:
@@ -237,4 +246,34 @@ class AdjustHueOp<CPUDevice> : public AdjustHueOpBase {
REGISTER_KERNEL_BUILDER(Name("AdjustHue").Device(DEVICE_CPU),
AdjustHueOp<CPUDevice>);
+#if GOOGLE_CUDA
+template <>
+class AdjustHueOp<GPUDevice> : public AdjustHueOpBase {
+ public:
+ explicit AdjustHueOp(OpKernelConstruction* context)
+ : AdjustHueOpBase(context) {}
+
+ virtual void DoCompute(OpKernelContext* context, const ComputeOptions& options) override {
+ const Tensor* input = options.input;
+ const Tensor* delta = options.delta;
+ Tensor* output = options.output;
+ const int64 number_of_elements = input->NumElements();
+ GPUDevice device = context->eigen_gpu_device();
+ const auto stream = device.stream();
+ OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
+ if (number_of_elements > 0) {
+ const float* input_data = input->flat<float>().data();
+ const float* delta_h = delta->flat<float>().data();
+ float* const output_data = output->flat<float>().data();
+ functor::AdjustHueGPU()(&device, number_of_elements, input_data, delta_h,
+ output_data);
+ }
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("AdjustHue").Device(DEVICE_GPU), AdjustHueOp<GPUDevice>);
+
+#endif
+
+//} // namespace functor
} // namespace tensorflow