-rw-r--r--  tensorflow/core/kernels/BUILD            |   1
-rw-r--r--  tensorflow/core/kernels/adjust_hue_op.cc | 240
-rw-r--r--  tensorflow/core/ops/image_ops.cc         |  23
-rw-r--r--  tensorflow/python/BUILD                  |   1
-rw-r--r--  tensorflow/python/ops/image_ops.py       |  31
-rw-r--r--  tensorflow/python/ops/image_ops_test.py  | 135
 6 files changed, 416 insertions(+), 15 deletions(-)
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index a84eaad315..a41bd867c1 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -1114,6 +1114,7 @@ tf_kernel_libraries(
    name = "image",
    prefixes = [
        "adjust_contrast_op",
+        "adjust_hue_op",
        "colorspace_op",
        "crop_and_resize_op",
        "decode_jpeg_op",
diff --git a/tensorflow/core/kernels/adjust_hue_op.cc b/tensorflow/core/kernels/adjust_hue_op.cc
new file mode 100644
index 0000000000..98934b4e5b
--- /dev/null
+++ b/tensorflow/core/kernels/adjust_hue_op.cc
@@ -0,0 +1,240 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/util/work_sharder.h"
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+class AdjustHueOpBase : public OpKernel {
+ protected:
+  AdjustHueOpBase(OpKernelConstruction* context) : OpKernel(context) {}
+
+  struct ComputeOptions {
+    const Tensor* input;
+    const Tensor* delta;
+    Tensor* output;
+    int64 channel_count;
+  };
+
+  virtual void DoCompute(OpKernelContext* context,
+                         const ComputeOptions& options) = 0;
+
+  void Compute(OpKernelContext* context) override {
+    const Tensor& input = context->input(0);
+    const Tensor& delta = context->input(1);
+    OP_REQUIRES(context, input.dims() >= 3,
+                errors::InvalidArgument("input must be at least 3-D, got shape ",
+                                        input.shape().DebugString()));
+    OP_REQUIRES(context, TensorShapeUtils::IsScalar(delta.shape()),
+                errors::InvalidArgument("delta must be scalar: ",
+                                        delta.shape().DebugString()));
+    auto channels = input.dim_size(input.dims() - 1);
+    OP_REQUIRES(
+        context, channels == 3,
+        errors::InvalidArgument("input must have 3 channels but instead has ",
+                                channels, " channels."));
+
+    Tensor* output = nullptr;
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(0, input.shape(), &output));
+
+    if (input.NumElements() > 0) {
+      const int64 channel_count = input.NumElements() / channels;
+      ComputeOptions options;
+      options.input = &input;
+      options.delta = &delta;
+      options.output = output;
+      options.channel_count = channel_count;
+      DoCompute(context, options);
+    }
+  }
+};
+
+template <class Device>
+class AdjustHueOp;
+
+namespace internal {
+// Helper function to convert an RGB color to H-and-V-range. H is in the
+// range of [0, 6] instead of the normal [0, 1].
+static void rgb_to_hv_range(float r, float g, float b, float* h, float* v_min,
+                            float* v_max) {
+  float v_mid;
+  int h_category;
+  // According to the figures in:
+  // https://en.wikipedia.org/wiki/HSL_and_HSV#Hue_and_chroma
+  // For the conditions, we don't care about the case where two components are
+  // equal. It is okay to count it on either side in that case.
+  if (r < g) {
+    if (b < r) {
+      // b < r < g
+      *v_max = g;
+      v_mid = r;
+      *v_min = b;
+      h_category = 1;
+    } else if (b > g) {
+      // r < g < b
+      *v_max = b;
+      v_mid = g;
+      *v_min = r;
+      h_category = 3;
+    } else {
+      // r < b < g
+      *v_max = g;
+      v_mid = b;
+      *v_min = r;
+      h_category = 2;
+    }
+  } else {
+    // g < r
+    if (b < g) {
+      // b < g < r
+      *v_max = r;
+      v_mid = g;
+      *v_min = b;
+      h_category = 0;
+    } else if (b > r) {
+      // g < r < b
+      *v_max = b;
+      v_mid = r;
+      *v_min = g;
+      h_category = 4;
+    } else {
+      // g < b < r
+      *v_max = r;
+      v_mid = b;
+      *v_min = g;
+      h_category = 5;
+    }
+  }
+  if (*v_max == *v_min) {
+    *h = 0;
+    return;
+  }
+  auto ratio = (v_mid - *v_min) / (*v_max - *v_min);
+  bool increase = ((h_category & 0x1) == 0);
+  *h = h_category + (increase ? ratio : (1 - ratio));
+}
+
+// Helper function to convert from H-and-V-range to RGB.
+static void hv_range_to_rgb(float h, float v_min, float v_max, float* r,
+                            float* g, float* b) {
+  int h_category = static_cast<int>(h);
+  float ratio = h - h_category;
+  bool increase = ((h_category & 0x1) == 0);
+  if (!increase) {
+    ratio = 1 - ratio;
+  }
+  float v_mid = v_min + ratio * (v_max - v_min);
+  // According to the figures in:
+  // https://en.wikipedia.org/wiki/HSL_and_HSV#Hue_and_chroma
+  switch (h_category) {
+    case 0:
+      *r = v_max;
+      *g = v_mid;
+      *b = v_min;
+      break;
+    case 1:
+      *r = v_mid;
+      *g = v_max;
+      *b = v_min;
+      break;
+    case 2:
+      *r = v_min;
+      *g = v_max;
+      *b = v_mid;
+      break;
+    case 3:
+      *r = v_min;
+      *g = v_mid;
+      *b = v_max;
+      break;
+    case 4:
+      *r = v_mid;
+      *g = v_min;
+      *b = v_max;
+      break;
+    case 5:
+    default:
+      *r = v_max;
+      *g = v_min;
+      *b = v_mid;
+  }
+}
+} // namespace internal
+
+template <>
+class AdjustHueOp<CPUDevice> : public AdjustHueOpBase {
+ public:
+  explicit AdjustHueOp(OpKernelConstruction* context)
+      : AdjustHueOpBase(context) {}
+
+  void DoCompute(OpKernelContext* context,
+                 const ComputeOptions& options) override {
+    const Tensor* input = options.input;
+    const Tensor* delta = options.delta;
+    Tensor* output = options.output;
+    const int64 channel_count = options.channel_count;
+    static const int kChannelSize = 3;
+    auto input_data = input->shaped<float, 2>({channel_count, kChannelSize});
+    const float delta_h = delta->scalar<float>()();
+    auto output_data = output->shaped<float, 2>({channel_count, kChannelSize});
+    const int kCostPerChannel = 10;
+    const DeviceBase::CpuWorkerThreads& worker_threads =
+        *context->device()->tensorflow_cpu_worker_threads();
+    Shard(worker_threads.num_threads, worker_threads.workers, channel_count,
+          kCostPerChannel, [channel_count, &input_data, &output_data, delta_h](
+                               int64 start_channel, int64 end_channel) {
+            const float* p = input_data.data() + start_channel * kChannelSize;
+            float* q = output_data.data() + start_channel * kChannelSize;
+            for (int i = start_channel; i < end_channel; i++) {
+              float h, v_min, v_max;
+              // Convert the RGB color to Hue/V-range.
+              internal::rgb_to_hv_range(p[0], p[1], p[2], &h, &v_min, &v_max);
+              static const int kChannelRange = 6;
+              // Adjust the hue value and wrap it back into the valid range of
+              // [0, 6). This is faster than fmod because it avoids a
+              // floating-point division and h is usually already close to the
+              // range.
+              h += delta_h * kChannelRange;
+              while (h < 0) {
+                h += kChannelRange;
+              }
+              while (h >= kChannelRange) {
+                h -= kChannelRange;
+              }
+              // Convert the hue and v-range back into RGB.
+              internal::hv_range_to_rgb(h, v_min, v_max, q, q + 1, q + 2);
+              p += kChannelSize;
+              q += kChannelSize;
+            }
+          });
+  }
+};
+
+REGISTER_KERNEL_BUILDER(Name("AdjustHue").Device(DEVICE_CPU),
+                        AdjustHueOp<CPUDevice>);
+
+} // namespace tensorflow
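Note: the kernel above represents hue as a sextant index plus a ratio, giving a value in [0, 6) that can be shifted and wrapped cheaply. The following is a minimal Python sketch of that mapping (illustrative only, not part of this change; the helper name simply mirrors the C++ one), with colorsys used as a cross-check:

    import colorsys

    def rgb_to_hv_range(r, g, b):
      # Pick the hue sextant (category 0..5) from the channel ordering, mirroring
      # the branch structure of the C++ helper above; ties may fall on either side.
      if r < g:
        if b < r:    # b < r < g
          v_max, v_mid, v_min, cat = g, r, b, 1
        elif b > g:  # r < g < b
          v_max, v_mid, v_min, cat = b, g, r, 3
        else:        # r <= b <= g
          v_max, v_mid, v_min, cat = g, b, r, 2
      else:
        if b < g:    # b < g <= r
          v_max, v_mid, v_min, cat = r, g, b, 0
        elif b > r:  # g <= r < b
          v_max, v_mid, v_min, cat = b, r, g, 4
        else:        # g <= b <= r
          v_max, v_mid, v_min, cat = r, b, g, 5
      if v_max == v_min:
        return 0.0, v_min, v_max  # grey pixel: hue is arbitrary, use 0
      ratio = (v_mid - v_min) / (v_max - v_min)
      rising = (cat % 2 == 0)     # even sextants rise with the ratio, odd ones fall
      return cat + (ratio if rising else 1.0 - ratio), v_min, v_max

    # colorsys reports hue in [0, 1), so divide the sextant-based value by 6.
    h, _, _ = rgb_to_hv_range(0.2, 0.7, 0.4)
    assert abs(h / 6.0 - colorsys.rgb_to_hsv(0.2, 0.7, 0.4)[0]) < 1e-6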
diff --git a/tensorflow/core/ops/image_ops.cc b/tensorflow/core/ops/image_ops.cc
index 626ce4aedb..93d845bdc4 100644
--- a/tensorflow/core/ops/image_ops.cc
+++ b/tensorflow/core/ops/image_ops.cc
@@ -441,6 +441,29 @@ output: The contrast-adjusted image or images.
)Doc");
// --------------------------------------------------------------------------
+REGISTER_OP("AdjustHue")
+    .Input("images: float")
+    .Input("delta: float")
+    .Output("output: float")
+    .SetShapeFn([](InferenceContext* c) {
+      return shape_inference::UnchangedShapeWithRankAtLeast(c, 3);
+    })
+    .Doc(R"Doc(
+Adjust the hue of one or more images.
+
+`images` is a tensor of at least 3 dimensions. The last dimension is
+interpreted as channels, and must be three.
+
+The input image is considered in the RGB colorspace. Conceptually, the RGB
+colors are first mapped into HSV. A delta is then applied to all the hue
+values, and the result is then mapped back to the RGB colorspace.
+
+images: Images to adjust. At least 3-D.
+delta: A float delta to add to the hue.
+output: The hue-adjusted image or images.
+)Doc");
+
+// --------------------------------------------------------------------------
REGISTER_OP("DecodePng")
    .Input("contents: string")
    .Attr("channels: int = 0")
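Note: a hedged usage sketch of the semantics documented above (last dimension must be 3, output shape matches the input, delta is a scalar float). It assumes the era-appropriate graph/session API and the public tf.image.adjust_hue wrapper defined in image_ops.py below:

    import numpy as np
    import tensorflow as tf

    image = np.random.rand(4, 4, 3).astype(np.float32)  # HWC, exactly 3 channels
    adjusted = tf.image.adjust_hue(image, delta=0.2)     # delta is a scalar float
    with tf.Session() as sess:
      out = sess.run(adjusted)
    print(out.shape)  # (4, 4, 3): the shape is unchanged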
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 9e96d740a5..7149b6d840 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -1557,6 +1557,7 @@ cuda_py_test(
        ":image_ops",
        ":io_ops",
        ":lib",
+        "//tensorflow:tensorflow_py",
    ],
    data = ["//tensorflow/core:image_testdata"],
    shard_count = 5,
diff --git a/tensorflow/python/ops/image_ops.py b/tensorflow/python/ops/image_ops.py
index 2836fbabdc..76f4d9fe15 100644
--- a/tensorflow/python/ops/image_ops.py
+++ b/tensorflow/python/ops/image_ops.py
@@ -164,6 +164,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import os
+
from tensorflow.python.framework import common_shapes
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -1043,10 +1045,11 @@ def adjust_gamma(image, gamma=1, gain=1):
    adjusted_img = (img / scale) ** gamma * scale * gain
    return adjusted_img
-
+
ops.RegisterShape('AdjustContrast')(common_shapes.call_cpp_shape_fn)
ops.RegisterShape('AdjustContrastv2')(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape('AdjustHue')(common_shapes.call_cpp_shape_fn)
ops.RegisterShape('DrawBoundingBoxes')(common_shapes.call_cpp_shape_fn)
ops.RegisterShape('SampleDistortedBoundingBox')(common_shapes.call_cpp_shape_fn)
@@ -1265,18 +1268,26 @@ def adjust_hue(image, delta, name=None):
    orig_dtype = image.dtype
    flt_image = convert_image_dtype(image, dtypes.float32)
-    hsv = gen_image_ops.rgb_to_hsv(flt_image)
+    # TODO(zhengxq): we will switch to the fused version after we add a GPU
+    # kernel for that.
+    fused = os.environ.get('TF_ADJUST_HUE_FUSED', '')
+    fused = fused.lower() in ('true', 't', '1')
-    hue = array_ops.slice(hsv, [0, 0, 0], [-1, -1, 1])
-    saturation = array_ops.slice(hsv, [0, 0, 1], [-1, -1, 1])
-    value = array_ops.slice(hsv, [0, 0, 2], [-1, -1, 1])
+    if not fused:
+      hsv = gen_image_ops.rgb_to_hsv(flt_image)
-    # Note that we add 2*pi to guarantee that the resulting hue is a positive
-    # floating point number since delta is [-0.5, 0.5].
-    hue = math_ops.mod(hue + (delta + 1.), 1.)
+      hue = array_ops.slice(hsv, [0, 0, 0], [-1, -1, 1])
+      saturation = array_ops.slice(hsv, [0, 0, 1], [-1, -1, 1])
+      value = array_ops.slice(hsv, [0, 0, 2], [-1, -1, 1])
-    hsv_altered = array_ops.concat(2, [hue, saturation, value])
-    rgb_altered = gen_image_ops.hsv_to_rgb(hsv_altered)
+      # Note that we add 1 (a full turn in normalized hue) to guarantee that
+      # the resulting hue is positive, since delta is in [-0.5, 0.5].
+      hue = math_ops.mod(hue + (delta + 1.), 1.)
+
+      hsv_altered = array_ops.concat(2, [hue, saturation, value])
+      rgb_altered = gen_image_ops.hsv_to_rgb(hsv_altered)
+    else:
+      rgb_altered = gen_image_ops.adjust_hue(flt_image, delta)
    return convert_image_dtype(rgb_altered, orig_dtype)
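Note: a hedged sketch of the TF_ADJUST_HUE_FUSED switch introduced above. The variable is read each time adjust_hue builds its graph, so it must be set before the call; the accepted values are the ones checked in the code:

    import os
    import numpy as np
    import tensorflow as tf

    x_np = np.random.rand(8, 8, 3).astype(np.float32)

    os.environ['TF_ADJUST_HUE_FUSED'] = '1'  # '1', 't' or 'true' selects the fused kernel
    y_fused = tf.image.adjust_hue(x_np, 0.25)

    os.environ['TF_ADJUST_HUE_FUSED'] = ''   # anything else keeps the HSV round-trip path
    y_unfused = tf.image.adjust_hue(x_np, 0.25)

    with tf.Session() as sess:
      a, b = sess.run([y_fused, y_unfused])
    print(np.max(np.abs(a - b)))  # the two paths should agree to float tolerance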
diff --git a/tensorflow/python/ops/image_ops_test.py b/tensorflow/python/ops/image_ops_test.py
index b47ecd2f8f..cc8622f217 100644
--- a/tensorflow/python/ops/image_ops_test.py
+++ b/tensorflow/python/ops/image_ops_test.py
@@ -18,11 +18,14 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import colorsys
import math
import os
+import time
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
+import tensorflow as tf
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -186,18 +189,18 @@ class AdjustGamma(test_util.TensorFlowTestCase):
    with self.test_session():
      x_data = np.random.uniform(0, 255, (8, 8))
      x_np = np.array(x_data, dtype=np.float32)
-
+
      x = constant_op.constant(x_np, shape=x_np.shape)
      y = image_ops.adjust_gamma(x, gamma=0)
-
+
      y_tf = y.eval()
      dtype = x.dtype.as_numpy_dtype
      y_np = np.array([dtypes.dtype_range[dtype][1]] * x_np.size)
      y_np = y_np.reshape((8,8))
-
+
      self.assertAllClose(y_tf, y_np, 1e-6)
-
+
  def test_adjust_gamma_less_one(self):
    """Verifying the output with expected results for gamma
@@ -215,7 +218,7 @@ class AdjustGamma(test_util.TensorFlowTestCase):
          [201, 204, 206, 209, 211, 214, 216, 218],
          [221, 223, 225, 228, 230, 232, 234, 236],
          [238, 241, 243, 245, 247, 249, 251, 253]], dtype=np.float32)
-
+
      self.assertAllClose(y_tf, y_np, 1e-6)
  def test_adjust_gamma_greater_one(self):
@@ -270,6 +273,128 @@ class AdjustHueTest(test_util.TensorFlowTestCase):
      y_tf = y.eval()
      self.assertAllEqual(y_tf, y_np)
+  def _adjustHueNp(self, x_np, delta_h):
+    self.assertEqual(x_np.shape[-1], 3)
+    x_v = x_np.reshape([-1, 3])
+    y_v = np.ndarray(x_v.shape, dtype=x_v.dtype)
+    channel_count = x_v.shape[0]
+    for i in xrange(channel_count):
+      r = x_v[i][0]
+      g = x_v[i][1]
+      b = x_v[i][2]
+      h, s, v = colorsys.rgb_to_hsv(r, g, b)
+      h += delta_h
+      h = math.fmod(h + 10.0, 1.0)
+      r, g, b = colorsys.hsv_to_rgb(h, s, v)
+      y_v[i][0] = r
+      y_v[i][1] = g
+      y_v[i][2] = b
+    return y_v.reshape(x_np.shape)
+
+  def _adjustHueTf(self, x_np, delta_h):
+    with self.test_session(use_gpu=False):
+      x = constant_op.constant(x_np)
+      y = image_ops.adjust_hue(x, delta_h)
+      y_tf = y.eval()
+      return y_tf
+
+  def testAdjustRandomHue(self):
+    x_shapes = [
+        [2, 2, 3],
+        [4, 2, 3],
+        [2, 4, 3],
+        [2, 5, 3],
+        [1000, 1, 3],
+    ]
+    test_styles = [
+        'all_random',
+        'rg_same',
+        'rb_same',
+        'gb_same',
+        'rgb_same',
+    ]
+    for x_shape in x_shapes:
+      for test_style in test_styles:
+        x_np = np.random.rand(*x_shape) * 255.
+        delta_h = np.random.rand() * 2.0 - 1.0
+        if test_style == 'all_random':
+          pass
+        elif test_style == 'rg_same':
+          x_np[..., 1] = x_np[..., 0]
+        elif test_style == 'rb_same':
+          x_np[..., 2] = x_np[..., 0]
+        elif test_style == 'gb_same':
+          x_np[..., 2] = x_np[..., 1]
+        elif test_style == 'rgb_same':
+          x_np[..., 1] = x_np[..., 0]
+          x_np[..., 2] = x_np[..., 0]
+        else:
+          raise AssertionError('Invalid test style: %s' % (test_style))
+        y_np = self._adjustHueNp(x_np, delta_h)
+        y_tf = self._adjustHueTf(x_np, delta_h)
+        self.assertAllClose(y_tf, y_np, rtol=2e-5, atol=1e-5)
+
+  def testInvalidShapes(self):
+    fused = False
+    if not fused:
+      # The tests are known to pass with the fused adjust_hue. We will enable
+      # them when the fused implementation is the default.
+      return
+    x_np = np.random.rand(2, 3) * 255.
+    delta_h = np.random.rand() * 2.0 - 1.0
+    fused = False
+    with self.assertRaisesRegexp(ValueError, 'Shape must be at least rank 3'):
+      self._adjustHueTf(x_np, delta_h)
+    x_np = np.random.rand(4, 2, 4) * 255.
+    delta_h = np.random.rand() * 2.0 - 1.0
+    with self.assertRaisesOpError('input must have 3 channels'):
+      self._adjustHueTf(x_np, delta_h)
+
+
+class AdjustHueBenchmark(test.Benchmark):
+
+  def _benchmarkAdjustHue(self, device, cpu_count):
+    image_shape = [299, 299, 3]
+    warmup_rounds = 100
+    benchmark_rounds = 1000
+    config = tf.ConfigProto()
+    if cpu_count is not None:
+      config.inter_op_parallelism_threads = 1
+      config.intra_op_parallelism_threads = cpu_count
+    with tf.Session('', graph=tf.Graph(), config=config) as sess:
+      with tf.device(device):
+        inputs = tf.Variable(
+            tf.random_uniform(
+                image_shape, dtype=tf.float32) * 255,
+            trainable=False,
+            dtype=tf.float32)
+        delta = tf.constant(0.1, dtype=tf.float32)
+        outputs = image_ops.adjust_hue(inputs, delta)
+        run_op = tf.group(outputs)
+        sess.run(tf.initialize_all_variables())
+        for i in xrange(warmup_rounds + benchmark_rounds):
+          if i == warmup_rounds:
+            start = time.time()
+          sess.run(run_op)
+    end = time.time()
+    step_time = (end - start) / benchmark_rounds
+    tag = '%s' % (cpu_count) if cpu_count is not None else '_all'
+    print('benchmarkAdjustHue_299_299_3_cpu%s step_time: %.2f us' %
+          (tag, step_time * 1e6))
+    self.report_benchmark(
+        name='benchmarkAdjustHue_299_299_3_cpu%s' % (tag),
+        iters=benchmark_rounds,
+        wall_time=step_time)
+
+  def benchmarkAdjustHueCpu1(self):
+    self._benchmarkAdjustHue('/cpu:0', 1)
+
+  def benchmarkAdjustHueCpuAll(self):
+    self._benchmarkAdjustHue('/cpu:0', None)
+
+  def benchmarkAdjustHueGpu(self):
+    self._benchmarkAdjustHue('/gpu:0', None)
+
class AdjustSaturationTest(test_util.TensorFlowTestCase):
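Note: the benchmark above discards the first warmup_rounds (100) iterations and reports the mean wall time of the remaining benchmark_rounds (1000) runs, printed in microseconds. A minimal mirror of that arithmetic (illustrative only; the --benchmarks invocation in the trailing comment is an assumption about the usual TF benchmark runner):

    import time

    def mean_step_time_us(run_op, warmup_rounds=100, benchmark_rounds=1000):
      start = None
      for i in range(warmup_rounds + benchmark_rounds):
        if i == warmup_rounds:
          start = time.time()
        run_op()
      end = time.time()
      return (end - start) / benchmark_rounds * 1e6

    print(mean_step_time_us(lambda: sum(range(1000))))
    # The benchmark class itself would typically be exercised through the test
    # file, e.g.:
    #   python tensorflow/python/ops/image_ops_test.py --benchmarks=AdjustHueBenchmark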