diff options
-rw-r--r-- | tensorflow/core/kernels/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/core/kernels/adjust_hue_op.cc | 240 | ||||
-rw-r--r-- | tensorflow/core/ops/image_ops.cc | 23 | ||||
-rw-r--r-- | tensorflow/python/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/python/ops/image_ops.py | 31 | ||||
-rw-r--r-- | tensorflow/python/ops/image_ops_test.py | 135 |
6 files changed, 416 insertions, 15 deletions
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index a84eaad315..a41bd867c1 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -1114,6 +1114,7 @@ tf_kernel_libraries( name = "image", prefixes = [ "adjust_contrast_op", + "adjust_hue_op", "colorspace_op", "crop_and_resize_op", "decode_jpeg_op", diff --git a/tensorflow/core/kernels/adjust_hue_op.cc b/tensorflow/core/kernels/adjust_hue_op.cc new file mode 100644 index 0000000000..98934b4e5b --- /dev/null +++ b/tensorflow/core/kernels/adjust_hue_op.cc @@ -0,0 +1,240 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include <memory> +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/util/work_sharder.h" + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +class AdjustHueOpBase : public OpKernel { + protected: + AdjustHueOpBase(OpKernelConstruction* context) : OpKernel(context) {} + + struct ComputeOptions { + const Tensor* input; + const Tensor* delta; + Tensor* output; + int64 channel_count; + }; + + virtual void DoCompute(OpKernelContext* context, + const ComputeOptions& options) = 0; + + void Compute(OpKernelContext* context) override { + const Tensor& input = context->input(0); + const Tensor& delta = context->input(1); + OP_REQUIRES(context, input.dims() >= 3, + errors::InvalidArgument("input must be at least 3-D, got shape", + input.shape().DebugString())); + OP_REQUIRES(context, TensorShapeUtils::IsScalar(delta.shape()), + errors::InvalidArgument("delta must be scalar: ", + delta.shape().DebugString())); + auto channels = input.dim_size(input.dims() - 1); + OP_REQUIRES( + context, channels == 3, + errors::InvalidArgument("input must have 3 channels but instead has ", + channels, " channels.")); + + Tensor* output = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output(0, input.shape(), &output)); + + if (input.NumElements() > 0) { + const int64 channel_count = input.NumElements() / channels; + ComputeOptions options; + options.input = &input; + options.delta = δ + options.output = output; + options.channel_count = channel_count; + DoCompute(context, options); + } + } +}; + +template <class Device> +class AdjustHueOp; + +namespace internal { +// Helper function to convert a RGB color to H-and-V-range. H is in the range +// of [0, 6] instead of the normal [0, 1] +static void rgb_to_hv_range(float r, float g, float b, float* h, float* v_min, + float* v_max) { + float v_mid; + int h_category; + // According to the figures in: + // https://en.wikipedia.org/wiki/HSL_and_HSV#Hue_and_chroma + // For the conditions, we don't care about the case where two components are + // equal. It is okay to count it in either side in that case. + if (r < g) { + if (b < r) { + // b < r < g + *v_max = g; + v_mid = r; + *v_min = b; + h_category = 1; + } else if (b > g) { + // r < g < b + *v_max = b; + v_mid = g; + *v_min = r; + h_category = 3; + } else { + // r < b < g + *v_max = g; + v_mid = b; + *v_min = r; + h_category = 2; + } + } else { + // g < r + if (b < g) { + // b < g < r + *v_max = r; + v_mid = g; + *v_min = b; + h_category = 0; + } else if (b > r) { + // g < r < b + *v_max = b; + v_mid = r; + *v_min = g; + h_category = 4; + } else { + // g < b < r + *v_max = r; + v_mid = b; + *v_min = g; + h_category = 5; + } + } + if (*v_max == *v_min) { + *h = 0; + return; + } + auto ratio = (v_mid - *v_min) / (*v_max - *v_min); + bool increase = ((h_category & 0x1) == 0); + *h = h_category + (increase ? ratio : (1 - ratio)); +} + +// Helper function to convert from H-and-V-range to RGB. +static void hv_range_to_rgb(float h, float v_min, float v_max, float* r, + float* g, float* b) { + int h_category = static_cast<int>(h); + float ratio = h - h_category; + bool increase = ((h_category & 0x1) == 0); + if (!increase) { + ratio = 1 - ratio; + } + float v_mid = v_min + ratio * (v_max - v_min); + // According to the figures in: + // https://en.wikipedia.org/wiki/HSL_and_HSV#Hue_and_chroma + switch (h_category) { + case 0: + *r = v_max; + *g = v_mid; + *b = v_min; + break; + case 1: + *r = v_mid; + *g = v_max; + *b = v_min; + break; + case 2: + *r = v_min; + *g = v_max; + *b = v_mid; + break; + case 3: + *r = v_min; + *g = v_mid; + *b = v_max; + break; + case 4: + *r = v_mid; + *g = v_min; + *b = v_max; + break; + case 5: + default: + *r = v_max; + *g = v_min; + *b = v_mid; + } +} +} // namespace internal + +template <> +class AdjustHueOp<CPUDevice> : public AdjustHueOpBase { + public: + explicit AdjustHueOp(OpKernelConstruction* context) + : AdjustHueOpBase(context) {} + + void DoCompute(OpKernelContext* context, + const ComputeOptions& options) override { + const Tensor* input = options.input; + const Tensor* delta = options.delta; + Tensor* output = options.output; + const int64 channel_count = options.channel_count; + static const int kChannelSize = 3; + auto input_data = input->shaped<float, 2>({channel_count, kChannelSize}); + const float delta_h = delta->scalar<float>()(); + auto output_data = output->shaped<float, 2>({channel_count, kChannelSize}); + const int kCostPerChannel = 10; + const DeviceBase::CpuWorkerThreads& worker_threads = + *context->device()->tensorflow_cpu_worker_threads(); + Shard(worker_threads.num_threads, worker_threads.workers, channel_count, + kCostPerChannel, [channel_count, &input_data, &output_data, delta_h]( + int64 start_channel, int64 end_channel) { + const float* p = input_data.data() + start_channel * kChannelSize; + float* q = output_data.data() + start_channel * kChannelSize; + for (int i = start_channel; i < end_channel; i++) { + float h, v_min, v_max; + // Convert the RGB color to Hue/V-range. + internal::rgb_to_hv_range(p[0], p[1], p[2], &h, &v_min, &v_max); + static const int kChannelRange = 6; + // Adjust the hue value. And adjust the hue back into the valid + // range of [0, 6). It is faster than a fmod by avoiding + // a float-point division since h is often very close to this + // range. + h += delta_h * kChannelRange; + while (h < 0) { + h += kChannelRange; + } + while (h >= kChannelRange) { + h -= kChannelRange; + } + // Convert the hue and v-range back into RGB. + internal::hv_range_to_rgb(h, v_min, v_max, q, q + 1, q + 2); + p += kChannelSize; + q += kChannelSize; + } + }); + } +}; + +REGISTER_KERNEL_BUILDER(Name("AdjustHue").Device(DEVICE_CPU), + AdjustHueOp<CPUDevice>); + +} // namespace tensorflow diff --git a/tensorflow/core/ops/image_ops.cc b/tensorflow/core/ops/image_ops.cc index 626ce4aedb..93d845bdc4 100644 --- a/tensorflow/core/ops/image_ops.cc +++ b/tensorflow/core/ops/image_ops.cc @@ -441,6 +441,29 @@ output: The contrast-adjusted image or images. )Doc"); // -------------------------------------------------------------------------- +REGISTER_OP("AdjustHue") + .Input("images: float") + .Input("delta: float") + .Output("output: float") + .SetShapeFn([](InferenceContext* c) { + return shape_inference::UnchangedShapeWithRankAtLeast(c, 3); + }) + .Doc(R"Doc( +Adjust the hue of one or more images. + +`images` is a tensor of at least 3 dimensions. The last dimension is +interpretted as channels, and must be three. + +The input image is considered in the RGB colorspace. Conceptually, the RGB +colors are first mapped into HSV. A delta is then applied all the hue values, +and then remapped back to RGB colorspace. + +images: Images to adjust. At least 3-D. +delta: A float delta to add to the hue. +output: The hue-adjusted image or images. +)Doc"); + +// -------------------------------------------------------------------------- REGISTER_OP("DecodePng") .Input("contents: string") .Attr("channels: int = 0") diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 9e96d740a5..7149b6d840 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -1557,6 +1557,7 @@ cuda_py_test( ":image_ops", ":io_ops", ":lib", + "//tensorflow:tensorflow_py", ], data = ["//tensorflow/core:image_testdata"], shard_count = 5, diff --git a/tensorflow/python/ops/image_ops.py b/tensorflow/python/ops/image_ops.py index 2836fbabdc..76f4d9fe15 100644 --- a/tensorflow/python/ops/image_ops.py +++ b/tensorflow/python/ops/image_ops.py @@ -164,6 +164,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os + from tensorflow.python.framework import common_shapes from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -1043,10 +1045,11 @@ def adjust_gamma(image, gamma=1, gain=1): adjusted_img = (img / scale) ** gamma * scale * gain return adjusted_img - + ops.RegisterShape('AdjustContrast')(common_shapes.call_cpp_shape_fn) ops.RegisterShape('AdjustContrastv2')(common_shapes.call_cpp_shape_fn) +ops.RegisterShape('AdjustHue')(common_shapes.call_cpp_shape_fn) ops.RegisterShape('DrawBoundingBoxes')(common_shapes.call_cpp_shape_fn) ops.RegisterShape('SampleDistortedBoundingBox')(common_shapes.call_cpp_shape_fn) @@ -1265,18 +1268,26 @@ def adjust_hue(image, delta, name=None): orig_dtype = image.dtype flt_image = convert_image_dtype(image, dtypes.float32) - hsv = gen_image_ops.rgb_to_hsv(flt_image) + # TODO(zhengxq): we will switch to the fused version after we add a GPU + # kernel for that. + fused = os.environ.get('TF_ADJUST_HUE_FUSED', '') + fused = fused.lower() in ('true', 't', '1') - hue = array_ops.slice(hsv, [0, 0, 0], [-1, -1, 1]) - saturation = array_ops.slice(hsv, [0, 0, 1], [-1, -1, 1]) - value = array_ops.slice(hsv, [0, 0, 2], [-1, -1, 1]) + if not fused: + hsv = gen_image_ops.rgb_to_hsv(flt_image) - # Note that we add 2*pi to guarantee that the resulting hue is a positive - # floating point number since delta is [-0.5, 0.5]. - hue = math_ops.mod(hue + (delta + 1.), 1.) + hue = array_ops.slice(hsv, [0, 0, 0], [-1, -1, 1]) + saturation = array_ops.slice(hsv, [0, 0, 1], [-1, -1, 1]) + value = array_ops.slice(hsv, [0, 0, 2], [-1, -1, 1]) - hsv_altered = array_ops.concat(2, [hue, saturation, value]) - rgb_altered = gen_image_ops.hsv_to_rgb(hsv_altered) + # Note that we add 2*pi to guarantee that the resulting hue is a positive + # floating point number since delta is [-0.5, 0.5]. + hue = math_ops.mod(hue + (delta + 1.), 1.) + + hsv_altered = array_ops.concat(2, [hue, saturation, value]) + rgb_altered = gen_image_ops.hsv_to_rgb(hsv_altered) + else: + rgb_altered = gen_image_ops.adjust_hue(flt_image, delta) return convert_image_dtype(rgb_altered, orig_dtype) diff --git a/tensorflow/python/ops/image_ops_test.py b/tensorflow/python/ops/image_ops_test.py index b47ecd2f8f..cc8622f217 100644 --- a/tensorflow/python/ops/image_ops_test.py +++ b/tensorflow/python/ops/image_ops_test.py @@ -18,11 +18,14 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import colorsys import math import os +import time import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin +import tensorflow as tf from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -186,18 +189,18 @@ class AdjustGamma(test_util.TensorFlowTestCase): with self.test_session(): x_data = np.random.uniform(0, 255, (8, 8)) x_np = np.array(x_data, dtype=np.float32) - + x = constant_op.constant(x_np, shape=x_np.shape) y = image_ops.adjust_gamma(x, gamma=0) - + y_tf = y.eval() dtype = x.dtype.as_numpy_dtype y_np = np.array([dtypes.dtype_range[dtype][1]] * x_np.size) y_np = y_np.reshape((8,8)) - + self.assertAllClose(y_tf, y_np, 1e-6) - + def test_adjust_gamma_less_one(self): """Verifying the output with expected results for gamma @@ -215,7 +218,7 @@ class AdjustGamma(test_util.TensorFlowTestCase): [201, 204, 206, 209, 211, 214, 216, 218], [221, 223, 225, 228, 230, 232, 234, 236], [238, 241, 243, 245, 247, 249, 251, 253]], dtype=np.float32) - + self.assertAllClose(y_tf, y_np, 1e-6) def test_adjust_gamma_greater_one(self): @@ -270,6 +273,128 @@ class AdjustHueTest(test_util.TensorFlowTestCase): y_tf = y.eval() self.assertAllEqual(y_tf, y_np) + def _adjustHueNp(self, x_np, delta_h): + self.assertEqual(x_np.shape[-1], 3) + x_v = x_np.reshape([-1, 3]) + y_v = np.ndarray(x_v.shape, dtype=x_v.dtype) + channel_count = x_v.shape[0] + for i in xrange(channel_count): + r = x_v[i][0] + g = x_v[i][1] + b = x_v[i][2] + h, s, v = colorsys.rgb_to_hsv(r, g, b) + h += delta_h + h = math.fmod(h + 10.0, 1.0) + r, g, b = colorsys.hsv_to_rgb(h, s, v) + y_v[i][0] = r + y_v[i][1] = g + y_v[i][2] = b + return y_v.reshape(x_np.shape) + + def _adjustHueTf(self, x_np, delta_h): + with self.test_session(use_gpu=False): + x = constant_op.constant(x_np) + y = image_ops.adjust_hue(x, delta_h) + y_tf = y.eval() + return y_tf + + def testAdjustRandomHue(self): + x_shapes = [ + [2, 2, 3], + [4, 2, 3], + [2, 4, 3], + [2, 5, 3], + [1000, 1, 3], + ] + test_styles = [ + 'all_random', + 'rg_same', + 'rb_same', + 'gb_same', + 'rgb_same', + ] + for x_shape in x_shapes: + for test_style in test_styles: + x_np = np.random.rand(*x_shape) * 255. + delta_h = np.random.rand() * 2.0 - 1.0 + if test_style == 'all_random': + pass + elif test_style == 'rg_same': + x_np[..., 1] = x_np[..., 0] + elif test_style == 'rb_same': + x_np[..., 2] = x_np[..., 0] + elif test_style == 'gb_same': + x_np[..., 2] = x_np[..., 1] + elif test_style == 'rgb_same': + x_np[..., 1] = x_np[..., 0] + x_np[..., 2] = x_np[..., 0] + else: + raise AssertionError('Invalid test style: %s' % (test_style)) + y_np = self._adjustHueNp(x_np, delta_h) + y_tf = self._adjustHueTf(x_np, delta_h) + self.assertAllClose(y_tf, y_np, rtol=2e-5, atol=1e-5) + + def testInvalidShapes(self): + fused = False + if not fused: + # The tests are known to pass with the fused adjust_hue. We will enable + # them when the fused implementation is the default. + return + x_np = np.random.rand(2, 3) * 255. + delta_h = np.random.rand() * 2.0 - 1.0 + fused = False + with self.assertRaisesRegexp(ValueError, 'Shape must be at least rank 3'): + self._adjustHueTf(x_np, delta_h) + x_np = np.random.rand(4, 2, 4) * 255. + delta_h = np.random.rand() * 2.0 - 1.0 + with self.assertRaisesOpError('input must have 3 channels'): + self._adjustHueTf(x_np, delta_h) + + +class AdjustHueBenchmark(test.Benchmark): + + def _benchmarkAdjustHue(self, device, cpu_count): + image_shape = [299, 299, 3] + warmup_rounds = 100 + benchmark_rounds = 1000 + config = tf.ConfigProto() + if cpu_count is not None: + config.inter_op_parallelism_threads = 1 + config.intra_op_parallelism_threads = cpu_count + with tf.Session('', graph=tf.Graph(), config=config) as sess: + with tf.device(device): + inputs = tf.Variable( + tf.random_uniform( + image_shape, dtype=tf.float32) * 255, + trainable=False, + dtype=tf.float32) + delta = tf.constant(0.1, dtype=tf.float32) + outputs = image_ops.adjust_hue(inputs, delta) + run_op = tf.group(outputs) + sess.run(tf.initialize_all_variables()) + for i in xrange(warmup_rounds + benchmark_rounds): + if i == warmup_rounds: + start = time.time() + sess.run(run_op) + end = time.time() + step_time = (end - start) / benchmark_rounds + tag = '%s' % (cpu_count) if cpu_count is not None else '_all' + print('benchmarkAdjustHue_299_299_3_cpu%s step_time: %.2f us' % + (tag, step_time * 1e6)) + self.report_benchmark( + name='benchmarkAdjustHue_299_299_3_cpu%s' % (tag), + iters=benchmark_rounds, + wall_time=step_time) + + def benchmarkAdjustHueCpu1(self): + self._benchmarkAdjustHue('/cpu:0', 1) + + def benchmarkAdjustHueCpuAll(self): + self._benchmarkAdjustHue('/cpu:0', None) + + def benchmarkAdjustHueGpu(self): + self._benchmarkAdjustHue('/gpu:0', None) + class AdjustSaturationTest(test_util.TensorFlowTestCase): |