aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/image
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-12-04 15:17:41 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-04 15:23:54 -0800
commite12c032397c54f41731d6924d38854eb2bbf5b4d (patch)
treebe671d56843dd31f7b3675b0ded8dd03800c2c9a /tensorflow/contrib/image
parentd87a76d7d6b053219dbd49a87b3c4b379a1c6566 (diff)
hsv_in_yiq gpu implementation.
PiperOrigin-RevId: 177876455
Diffstat (limited to 'tensorflow/contrib/image')
-rwxr-xr-xtensorflow/contrib/image/BUILD39
-rw-r--r--tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.cc88
-rw-r--r--tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.h87
-rw-r--r--tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op_gpu.cu.cc84
-rw-r--r--tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op_test.cc48
-rw-r--r--tensorflow/contrib/image/python/kernel_tests/distort_image_ops_test.py10
6 files changed, 294 insertions, 62 deletions
diff --git a/tensorflow/contrib/image/BUILD b/tensorflow/contrib/image/BUILD
index 157e97d237..54502cfc6e 100755
--- a/tensorflow/contrib/image/BUILD
+++ b/tensorflow/contrib/image/BUILD
@@ -9,6 +9,7 @@ package(default_visibility = ["//visibility:public"])
load(
"//tensorflow:tensorflow.bzl",
+ "tf_cc_test",
"tf_custom_op_library",
"tf_gen_op_libs",
"tf_gen_op_wrapper_py",
@@ -106,10 +107,33 @@ tf_custom_op_library(
name = "python/ops/_distort_image_ops.so",
srcs = [
"kernels/adjust_hsv_in_yiq_op.cc",
+ "kernels/adjust_hsv_in_yiq_op.h",
"ops/distort_image_ops.cc",
],
+ gpu_srcs = [
+ "kernels/adjust_hsv_in_yiq_op_gpu.cu.cc",
+ "kernels/adjust_hsv_in_yiq_op.h",
+ ],
deps = [
- "@protobuf_archive//:protobuf",
+ "//tensorflow/core/kernels:gpu_util_hdrs",
+ ],
+)
+
+tf_cc_test(
+ name = "adjust_hsv_in_yiq_op_test",
+ size = "small",
+ srcs = [
+ "kernels/adjust_hsv_in_yiq_op.h",
+ "kernels/adjust_hsv_in_yiq_op_test.cc",
+ ],
+ deps = [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ "//tensorflow/core/kernels:ops_testutil",
+ "//tensorflow/core/kernels:ops_util",
+ "//third_party/eigen3",
],
)
@@ -122,19 +146,6 @@ tf_gen_op_wrapper_py(
deps = [":distort_image_ops_op_lib"],
)
-cc_library(
- name = "distort_image_ops_cc",
- srcs = [
- "kernels/adjust_hsv_in_yiq_op.cc",
- ],
- deps = [
- "//tensorflow/core:framework",
- "//tensorflow/core:lib",
- "//third_party/eigen3",
- ],
- alwayslink = 1,
-)
-
py_library(
name = "distort_image_py",
srcs = [
diff --git a/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.cc b/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.cc
index f4962ed69d..478b716d88 100644
--- a/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.cc
+++ b/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.cc
@@ -12,14 +12,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include <cmath>
+#if GOOGLE_CUDA
+#define EIGEN_USE_GPU
+#endif
+
+#include "tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.h"
#include <memory>
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
-#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/work_sharder.h"
@@ -36,10 +37,10 @@ class AdjustHsvInYiqOpBase : public OpKernel {
struct ComputeOptions {
const Tensor* input = nullptr;
+ Tensor* output = nullptr;
const Tensor* delta_h = nullptr;
const Tensor* scale_s = nullptr;
const Tensor* scale_v = nullptr;
- Tensor* output = nullptr;
int64 channel_count = 0;
};
@@ -65,7 +66,7 @@ class AdjustHsvInYiqOpBase : public OpKernel {
scale_v.shape().DebugString()));
auto channels = input.dim_size(input.dims() - 1);
OP_REQUIRES(
- context, channels == 3,
+ context, channels == kChannelSize,
errors::InvalidArgument("input must have 3 channels but instead has ",
channels, " channels."));
@@ -101,53 +102,21 @@ class AdjustHsvInYiqOp<CPUDevice> : public AdjustHsvInYiqOpBase {
const Tensor* input = options.input;
Tensor* output = options.output;
const int64 channel_count = options.channel_count;
- static const int kChannelSize = 3;
auto input_data = input->shaped<float, 2>({channel_count, kChannelSize});
const float delta_h = options.delta_h->scalar<float>()();
const float scale_s = options.scale_s->scalar<float>()();
const float scale_v = options.scale_v->scalar<float>()();
auto output_data = output->shaped<float, 2>({channel_count, kChannelSize});
+ float tranformation_matrix[kChannelSize * kChannelSize] = {0};
+ internal::compute_tranformation_matrix<kChannelSize * kChannelSize>(
+ delta_h, scale_s, scale_v, tranformation_matrix);
const int kCostPerChannel = 10;
const DeviceBase::CpuWorkerThreads& worker_threads =
*context->device()->tensorflow_cpu_worker_threads();
Shard(worker_threads.num_threads, worker_threads.workers, channel_count,
kCostPerChannel,
- [channel_count, &input_data, &output_data, delta_h, scale_s, scale_v](
+ [channel_count, &input_data, &output_data, &tranformation_matrix](
int64 start_channel, int64 end_channel) {
- // Using approximate linear transfomation described in:
- // https://beesbuzz.biz/code/hsv_color_transforms.php
- /** Get the constants from sympy
- from sympy import Matrix
- from sympy.abc import u, w
- # Projection matrix to YIQ. http://en.wikipedia.org/wiki/YIQ
- tyiq = Matrix([[0.299, 0.587, 0.114],
- [0.596, -0.274, -0.322],
- [0.211, -0.523, 0.312]])
- # Hue rotation matrix in YIQ space.
- hue_proj = Matrix(3,3, [v, 0, 0, 0, vsu, -vsw, 0, vsw, vsu])
- m = tyiq.inv() * hue_proj * tyiq
- **/
- // TODO(huangyp): directly compute the projection matrix from tyiq.
- static const float t[kChannelSize][kChannelSize][kChannelSize] = {
- {{.299, .701, .16862179492229},
- {.587, -.587, .329804745287403},
- {.114, -.114, -0.498426540209694}},
- {{.299, -.299, -.327963394172371},
- {.587, .413, .0346106879248821},
- {.114, -.114, .293352706247489}},
- {{.299, -.299, 1.24646136576682},
- {.587, -.587, -1.04322888291964},
- {.114, .886, -.203232482847173}}};
- float m[kChannelSize][kChannelSize] = {{0.}};
- float su = scale_s * std::cos(delta_h);
- float sw = scale_s * std::sin(delta_h);
- for (int q_index = 0; q_index < kChannelSize; q_index++) {
- for (int p_index = 0; p_index < kChannelSize; p_index++) {
- m[q_index][p_index] = scale_v * (t[q_index][p_index][0] +
- t[q_index][p_index][1] * su +
- t[q_index][p_index][2] * sw);
- }
- }
// Applying projection matrix to input RGB vectors.
const float* p = input_data.data() + start_channel * kChannelSize;
float* q = output_data.data() + start_channel * kChannelSize;
@@ -155,7 +124,9 @@ class AdjustHsvInYiqOp<CPUDevice> : public AdjustHsvInYiqOpBase {
for (int q_index = 0; q_index < kChannelSize; q_index++) {
q[q_index] = 0;
for (int p_index = 0; p_index < kChannelSize; p_index++) {
- q[q_index] += m[q_index][p_index] * p[p_index];
+ q[q_index] +=
+ p[p_index] *
+ tranformation_matrix[q_index + kChannelSize * p_index];
}
}
p += kChannelSize;
@@ -165,8 +136,33 @@ class AdjustHsvInYiqOp<CPUDevice> : public AdjustHsvInYiqOpBase {
}
};
-REGISTER_KERNEL_BUILDER(Name("AdjustHsvInYiq").Device(DEVICE_CPU),
- AdjustHsvInYiqOp<CPUDevice>);
+REGISTER_KERNEL_BUILDER(
+ Name("AdjustHsvInYiq").Device(DEVICE_CPU).TypeConstraint<float>("T"),
+ AdjustHsvInYiqOp<CPUDevice>);
+
+#if GOOGLE_CUDA
+template <>
+class AdjustHsvInYiqOp<GPUDevice> : public AdjustHsvInYiqOpBase {
+ public:
+ explicit AdjustHsvInYiqOp(OpKernelConstruction* context)
+ : AdjustHsvInYiqOpBase(context) {}
+
+ void DoCompute(OpKernelContext* ctx, const ComputeOptions& options) override {
+ const int64 number_of_elements = options.input->NumElements();
+ if (number_of_elements <= 0) {
+ return;
+ }
+ const float* delta_h = options.delta_h->flat<float>().data();
+ const float* scale_s = options.scale_s->flat<float>().data();
+ const float* scale_v = options.scale_v->flat<float>().data();
+ functor::AdjustHsvInYiqGPU()(ctx, options.channel_count, options.input,
+ delta_h, scale_s, scale_v, options.output);
+ }
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("AdjustHsvInYiq").Device(DEVICE_GPU).TypeConstraint<float>("T"),
+ AdjustHsvInYiqOp<GPUDevice>);
+#endif
-// TODO(huangyp): add the GPU kernel
} // namespace tensorflow
diff --git a/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.h b/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.h
new file mode 100644
index 0000000000..194ae2ba47
--- /dev/null
+++ b/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.h
@@ -0,0 +1,87 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_IMAGE_KERNELS_ADJUST_HSV_IN_YIQ_OP_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_IMAGE_KERNELS_ADJUST_HSV_IN_YIQ_OP_H_
+
+#if GOOGLE_CUDA
+#define EIGEN_USE_GPU
+#endif // GOOGLE_CUDA
+
+#include <cmath>
+#include "third_party/eigen3/Eigen/Core"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/types.h"
+
+namespace tensorflow {
+
+static constexpr int kChannelSize = 3;
+
+namespace internal {
+
+template <int MATRIX_SIZE>
+EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void compute_tranformation_matrix(
+ const float delta_h, const float scale_s, const float scale_v,
+ float* matrix) {
+ static_assert(MATRIX_SIZE == kChannelSize * kChannelSize,
+ "Size of matrix should be 9.");
+ // Projection matrix from RGB to YIQ. Numbers from Wikipedia:
+ // https://en.wikipedia.org/wiki/YIQ
+ Eigen::Matrix3f yiq;
+ /* clang-format off */
+ yiq << 0.299, 0.587, 0.114,
+ 0.596, -0.274, -0.322,
+ 0.211, -0.523, 0.312;
+ Eigen::Matrix3f yiq_inverse;
+ yiq_inverse << 1, 0.95617069, 0.62143257,
+ 1, -0.2726886, -0.64681324,
+ 1, -1.103744, 1.70062309;
+ /* clang-format on */
+ // Construct hsv linear transformation matrix in YIQ space.
+ // https://beesbuzz.biz/code/hsv_color_transforms.php
+ float vsu = scale_v * scale_s * std::cos(delta_h);
+ float vsw = scale_v * scale_s * std::sin(delta_h);
+ Eigen::Matrix3f hsv_transform;
+ /* clang-format off */
+ hsv_transform << scale_v, 0, 0,
+ 0, vsu, -vsw,
+ 0, vsw, vsu;
+ /* clang-format on */
+ // Compute final transformation matrix = yiq_inverse * hsv_transform * yiq
+ Eigen::Map<Eigen::Matrix<float, 3, 3, Eigen::ColMajor>> eigen_matrix(matrix);
+ eigen_matrix = yiq_inverse * hsv_transform * yiq;
+}
+} // namespace internal
+
+#if GOOGLE_CUDA
+typedef Eigen::GpuDevice GPUDevice;
+
+namespace functor {
+
+struct AdjustHsvInYiqGPU {
+ void operator()(OpKernelContext* ctx, int channel_count,
+ const Tensor* const input, const float* const delta_h,
+ const float* const scale_s, const float* const scale_v,
+ Tensor* const output);
+};
+
+} // namespace functor
+
+#endif // GOOGLE_CUDA
+
+} // namespace tensorflow
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_IMAGE_KERNELS_ADJUST_HSV_IN_YIQ_OP_H_
diff --git a/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op_gpu.cu.cc b/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op_gpu.cu.cc
new file mode 100644
index 0000000000..b71ff9cd50
--- /dev/null
+++ b/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op_gpu.cu.cc
@@ -0,0 +1,84 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include "tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.h"
+#include "tensorflow/core/kernels/gpu_utils.h"
+#include "tensorflow/core/platform/stream_executor.h"
+#include "tensorflow/core/util/cuda_kernel_helper.h"
+
+namespace tensorflow {
+
+namespace internal {
+
+__global__ void compute_tranformation_matrix_cuda(const float* const delta_h,
+ const float* const scale_s,
+ const float* const scale_v,
+ float* const matrix,
+ const int matrix_size) {
+ if (matrix_size == kChannelSize * kChannelSize) {
+ compute_tranformation_matrix<kChannelSize * kChannelSize>(
+ *delta_h, *scale_s, *scale_v, matrix);
+ }
+}
+} // namespace internal
+
+namespace functor {
+
+void AdjustHsvInYiqGPU::operator()(OpKernelContext* ctx, int channel_count,
+ const Tensor* const input,
+ const float* const delta_h,
+ const float* const scale_s,
+ const float* const scale_v,
+ Tensor* const output) {
+ const uint64 m = channel_count;
+ const uint64 k = kChannelSize;
+ const uint64 n = kChannelSize;
+ auto* cu_stream = ctx->eigen_device<GPUDevice>().stream();
+ OP_REQUIRES(ctx, cu_stream, errors::Internal("No GPU stream available."));
+ Tensor tranformation_matrix;
+ OP_REQUIRES_OK(ctx, ctx->allocate_temp(
+ DT_FLOAT, TensorShape({kChannelSize * kChannelSize}),
+ &tranformation_matrix));
+ // TODO(huangyp): It takes about 3.5 us to compute tranformation_matrix
+ // with one thread. Improve its performance if necessary.
+ internal::compute_tranformation_matrix_cuda<<<1, 1, 0, cu_stream>>>(
+ delta_h, scale_s, scale_v, tranformation_matrix.flat<float>().data(),
+ tranformation_matrix.flat<float>().size());
+ // Call cuBlas C = A * B directly.
+ auto no_transpose = perftools::gputools::blas::Transpose::kNoTranspose;
+ auto a_ptr =
+ AsDeviceMemory(input->flat<float>().data(), input->flat<float>().size());
+ auto b_ptr = AsDeviceMemory(tranformation_matrix.flat<float>().data(),
+ tranformation_matrix.flat<float>().size());
+ auto c_ptr = AsDeviceMemory(output->flat<float>().data(),
+ output->flat<float>().size());
+ auto* stream = ctx->op_device_context()->stream();
+ OP_REQUIRES(ctx, stream, errors::Internal("No GPU stream available."));
+ // TODO(huangyp): share/use autotune cublas algorithms in Matmul.op.
+ bool blas_launch_status =
+ stream
+ ->ThenBlasGemm(no_transpose, no_transpose, n, m, k, 1.0f, b_ptr, n,
+ a_ptr, k, 0.0f, &c_ptr, n)
+ .ok();
+ if (!blas_launch_status) {
+ ctx->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m,
+ ", n=", n, ", k=", k));
+ }
+}
+} // namespace functor
+} // namespace tensorflow
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op_test.cc b/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op_test.cc
new file mode 100644
index 0000000000..4cbbd27784
--- /dev/null
+++ b/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op_test.cc
@@ -0,0 +1,48 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+
+class AdjustHsvInYiqOpTest : public OpsTestBase {
+ protected:
+};
+
+TEST_F(AdjustHsvInYiqOpTest, IdentiyTransformMatrix) {
+ Tensor matrix(allocator(), DT_FLOAT, TensorShape({9}));
+ internal::compute_tranformation_matrix<9>(0.0, 1.0, 1.0,
+ matrix.flat<float>().data());
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({9}));
+ test::FillValues<float>(&expected, {1, 0, 0, 0, 1, 0, 0, 0, 1});
+ test::ExpectClose(matrix, expected);
+}
+
+TEST_F(AdjustHsvInYiqOpTest, ScaleValueTransformMatrix) {
+ float scale_v = 2.3;
+ Tensor matrix(allocator(), DT_FLOAT, TensorShape({9}));
+ internal::compute_tranformation_matrix<9>(0.0, 1.0, scale_v,
+ matrix.flat<float>().data());
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({9}));
+ test::FillValues<float>(&expected,
+ {scale_v, 0, 0, 0, scale_v, 0, 0, 0, scale_v});
+ test::ExpectClose(matrix, expected);
+}
+
+} // end namespace tensorflow
diff --git a/tensorflow/contrib/image/python/kernel_tests/distort_image_ops_test.py b/tensorflow/contrib/image/python/kernel_tests/distort_image_ops_test.py
index b85f19d29b..a495b58b7f 100644
--- a/tensorflow/contrib/image/python/kernel_tests/distort_image_ops_test.py
+++ b/tensorflow/contrib/image/python/kernel_tests/distort_image_ops_test.py
@@ -172,7 +172,7 @@ class AdjustValueInYiqTest(test_util.TensorFlowTestCase):
raise AssertionError('Invalid test style: %s' % (test_style))
y_np = self._adjust_value_in_yiq_np(x_np, scale)
y_tf = self._adjust_value_in_yiq_tf(x_np, scale)
- self.assertAllClose(y_tf, y_np, rtol=2e-5, atol=1e-5)
+ self.assertAllClose(y_tf, y_np, rtol=2e-4, atol=1e-4)
def test_invalid_shapes(self):
x_np = np.random.rand(2, 3) * 255.
@@ -237,7 +237,7 @@ class AdjustSaturationInYiqTest(test_util.TensorFlowTestCase):
raise AssertionError('Invalid test style: %s' % (test_style))
y_baseline = self._adjust_saturation_in_yiq_np(x_np, scale)
y_tf = self._adjust_saturation_in_yiq_tf(x_np, scale)
- self.assertAllClose(y_tf, y_baseline, rtol=2e-5, atol=1e-5)
+ self.assertAllClose(y_tf, y_baseline, rtol=2e-4, atol=1e-4)
def test_invalid_shapes(self):
x_np = np.random.rand(2, 3) * 255.
@@ -291,6 +291,9 @@ class AdjustHueInYiqBenchmark(test.Benchmark):
def benchmark_adjust_hue_in_yiqCpuAll(self):
self._benchmark_adjust_hue_in_yiq('/cpu:0', None)
+ def benchmark_adjust_hue_in_yiq_gpu_all(self):
+ self._benchmark_adjust_hue_in_yiq(test.gpu_device_name(), None)
+
class AdjustSaturationInYiqBenchmark(test.Benchmark):
@@ -333,6 +336,9 @@ class AdjustSaturationInYiqBenchmark(test.Benchmark):
def benchmark_adjust_saturation_in_yiq_cpu_all(self):
self._benchmark_adjust_saturation_in_yiq('/cpu:0', None)
+ def benchmark_adjust_saturation_in_yiq_gpu_all(self):
+ self._benchmark_adjust_saturation_in_yiq(test.gpu_device_name(), None)
+
if __name__ == '__main__':
googletest.main()