aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/image/python
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-09-15 15:28:15 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-09-15 15:31:39 -0700
commitbc68dc843f43e6afd9ef2ba207cfa3f0f4a9db4e (patch)
tree5bcef503755f375922489d9c32d1a9b9cf85caa0 /tensorflow/contrib/image/python
parent30868ef86771acf8632bd8991b65f47e5ce3756e (diff)
Add ops that perform color transforms (including changing value, saturation and hue) in YIQ space.
PiperOrigin-RevId: 168897736
Diffstat (limited to 'tensorflow/contrib/image/python')
-rw-r--r--tensorflow/contrib/image/python/kernel_tests/distort_image_ops_test.py338
-rw-r--r--tensorflow/contrib/image/python/ops/distort_image_ops.py138
2 files changed, 476 insertions, 0 deletions
diff --git a/tensorflow/contrib/image/python/kernel_tests/distort_image_ops_test.py b/tensorflow/contrib/image/python/kernel_tests/distort_image_ops_test.py
new file mode 100644
index 0000000000..b85f19d29b
--- /dev/null
+++ b/tensorflow/contrib/image/python/kernel_tests/distort_image_ops_test.py
@@ -0,0 +1,338 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for python distort_image_ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import time
+
+import numpy as np
+from six.moves import xrange # pylint: disable=redefined-builtin
+
+from tensorflow.contrib.image.python.ops import distort_image_ops
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.client import session
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import googletest
+from tensorflow.python.platform import test
+
+
+# TODO(huangyp): also measure the differences between AdjustHsvInYiq and
+# AdjustHsv in core.
+class AdjustHueInYiqTest(test_util.TensorFlowTestCase):
+
+ def _adjust_hue_in_yiq_np(self, x_np, delta_h):
+ """Rotate hue in YIQ space.
+
+ Mathematically we first convert rgb color to yiq space, rotate the hue
+ degrees, and then convert back to rgb.
+
+ Args:
+ x_np: input x with last dimension = 3.
+ delta_h: degree of hue rotation, in radians.
+
+ Returns:
+ Adjusted y with the same shape as x_np.
+ """
+ self.assertEqual(x_np.shape[-1], 3)
+ x_v = x_np.reshape([-1, 3])
+ y_v = np.ndarray(x_v.shape, dtype=x_v.dtype)
+ u = np.cos(delta_h)
+ w = np.sin(delta_h)
+ # Projection matrix from RGB to YIQ. Numbers from wikipedia
+ # https://en.wikipedia.org/wiki/YIQ
+ tyiq = np.array([[0.299, 0.587, 0.114], [0.596, -0.274, -0.322],
+ [0.211, -0.523, 0.312]])
+ y_v = np.dot(x_v, tyiq.T)
+ # Hue rotation matrix in YIQ space.
+ hue_rotation = np.array([[1.0, 0.0, 0.0], [0.0, u, -w], [0.0, w, u]])
+ y_v = np.dot(y_v, hue_rotation.T)
+ # Projecting back to RGB space.
+ y_v = np.dot(y_v, np.linalg.inv(tyiq).T)
+ return y_v.reshape(x_np.shape)
+
+ def _adjust_hue_in_yiq_tf(self, x_np, delta_h):
+ with self.test_session(use_gpu=True):
+ x = constant_op.constant(x_np)
+ y = distort_image_ops.adjust_hsv_in_yiq(x, delta_h, 1, 1)
+ y_tf = y.eval()
+ return y_tf
+
+ def test_adjust_random_hue_in_yiq(self):
+ x_shapes = [
+ [2, 2, 3],
+ [4, 2, 3],
+ [2, 4, 3],
+ [2, 5, 3],
+ [1000, 1, 3],
+ ]
+ test_styles = [
+ 'all_random',
+ 'rg_same',
+ 'rb_same',
+ 'gb_same',
+ 'rgb_same',
+ ]
+ for x_shape in x_shapes:
+ for test_style in test_styles:
+ x_np = np.random.rand(*x_shape) * 255.
+ delta_h = (np.random.rand() * 2.0 - 1.0) * np.pi
+ if test_style == 'all_random':
+ pass
+ elif test_style == 'rg_same':
+ x_np[..., 1] = x_np[..., 0]
+ elif test_style == 'rb_same':
+ x_np[..., 2] = x_np[..., 0]
+ elif test_style == 'gb_same':
+ x_np[..., 2] = x_np[..., 1]
+ elif test_style == 'rgb_same':
+ x_np[..., 1] = x_np[..., 0]
+ x_np[..., 2] = x_np[..., 0]
+ else:
+ raise AssertionError('Invalid test style: %s' % (test_style))
+ y_np = self._adjust_hue_in_yiq_np(x_np, delta_h)
+ y_tf = self._adjust_hue_in_yiq_tf(x_np, delta_h)
+ self.assertAllClose(y_tf, y_np, rtol=2e-4, atol=1e-4)
+
+ def test_invalid_shapes(self):
+ x_np = np.random.rand(2, 3) * 255.
+ delta_h = np.random.rand() * 2.0 - 1.0
+ with self.assertRaisesRegexp(ValueError, 'Shape must be at least rank 3'):
+ self._adjust_hue_in_yiq_tf(x_np, delta_h)
+ x_np = np.random.rand(4, 2, 4) * 255.
+ delta_h = np.random.rand() * 2.0 - 1.0
+ with self.assertRaisesOpError('input must have 3 channels but instead has '
+ '4 channels'):
+ self._adjust_hue_in_yiq_tf(x_np, delta_h)
+
+
+class AdjustValueInYiqTest(test_util.TensorFlowTestCase):
+
+ def _adjust_value_in_yiq_np(self, x_np, scale):
+ return x_np * scale
+
+ def _adjust_value_in_yiq_tf(self, x_np, scale):
+ with self.test_session(use_gpu=True):
+ x = constant_op.constant(x_np)
+ y = distort_image_ops.adjust_hsv_in_yiq(x, 0, 1, scale)
+ y_tf = y.eval()
+ return y_tf
+
+ def test_adjust_random_value_in_yiq(self):
+ x_shapes = [
+ [2, 2, 3],
+ [4, 2, 3],
+ [2, 4, 3],
+ [2, 5, 3],
+ [1000, 1, 3],
+ ]
+ test_styles = [
+ 'all_random',
+ 'rg_same',
+ 'rb_same',
+ 'gb_same',
+ 'rgb_same',
+ ]
+ for x_shape in x_shapes:
+ for test_style in test_styles:
+ x_np = np.random.rand(*x_shape) * 255.
+ scale = np.random.rand() * 2.0 - 1.0
+ if test_style == 'all_random':
+ pass
+ elif test_style == 'rg_same':
+ x_np[..., 1] = x_np[..., 0]
+ elif test_style == 'rb_same':
+ x_np[..., 2] = x_np[..., 0]
+ elif test_style == 'gb_same':
+ x_np[..., 2] = x_np[..., 1]
+ elif test_style == 'rgb_same':
+ x_np[..., 1] = x_np[..., 0]
+ x_np[..., 2] = x_np[..., 0]
+ else:
+ raise AssertionError('Invalid test style: %s' % (test_style))
+ y_np = self._adjust_value_in_yiq_np(x_np, scale)
+ y_tf = self._adjust_value_in_yiq_tf(x_np, scale)
+ self.assertAllClose(y_tf, y_np, rtol=2e-5, atol=1e-5)
+
+ def test_invalid_shapes(self):
+ x_np = np.random.rand(2, 3) * 255.
+ scale = np.random.rand() * 2.0 - 1.0
+ with self.assertRaisesRegexp(ValueError, 'Shape must be at least rank 3'):
+ self._adjust_value_in_yiq_tf(x_np, scale)
+ x_np = np.random.rand(4, 2, 4) * 255.
+ scale = np.random.rand() * 2.0 - 1.0
+ with self.assertRaisesOpError('input must have 3 channels but instead has '
+ '4 channels'):
+ self._adjust_value_in_yiq_tf(x_np, scale)
+
+
+class AdjustSaturationInYiqTest(test_util.TensorFlowTestCase):
+
+ def _adjust_saturation_in_yiq_tf(self, x_np, scale):
+ with self.test_session(use_gpu=True):
+ x = constant_op.constant(x_np)
+ y = distort_image_ops.adjust_hsv_in_yiq(x, 0, scale, 1)
+ y_tf = y.eval()
+ return y_tf
+
+ def _adjust_saturation_in_yiq_np(self, x_np, scale):
+ """Adjust saturation using linear interpolation."""
+ rgb_weights = np.array([0.299, 0.587, 0.114])
+ gray = np.sum(x_np * rgb_weights, axis=-1, keepdims=True)
+ y_v = x_np * scale + gray * (1 - scale)
+ return y_v
+
+ def test_adjust_random_saturation_in_yiq(self):
+ x_shapes = [
+ [2, 2, 3],
+ [4, 2, 3],
+ [2, 4, 3],
+ [2, 5, 3],
+ [1000, 1, 3],
+ ]
+ test_styles = [
+ 'all_random',
+ 'rg_same',
+ 'rb_same',
+ 'gb_same',
+ 'rgb_same',
+ ]
+ with self.test_session():
+ for x_shape in x_shapes:
+ for test_style in test_styles:
+ x_np = np.random.rand(*x_shape) * 255.
+ scale = np.random.rand() * 2.0 - 1.0
+ if test_style == 'all_random':
+ pass
+ elif test_style == 'rg_same':
+ x_np[..., 1] = x_np[..., 0]
+ elif test_style == 'rb_same':
+ x_np[..., 2] = x_np[..., 0]
+ elif test_style == 'gb_same':
+ x_np[..., 2] = x_np[..., 1]
+ elif test_style == 'rgb_same':
+ x_np[..., 1] = x_np[..., 0]
+ x_np[..., 2] = x_np[..., 0]
+ else:
+ raise AssertionError('Invalid test style: %s' % (test_style))
+ y_baseline = self._adjust_saturation_in_yiq_np(x_np, scale)
+ y_tf = self._adjust_saturation_in_yiq_tf(x_np, scale)
+ self.assertAllClose(y_tf, y_baseline, rtol=2e-5, atol=1e-5)
+
+ def test_invalid_shapes(self):
+ x_np = np.random.rand(2, 3) * 255.
+ scale = np.random.rand() * 2.0 - 1.0
+ with self.assertRaisesRegexp(ValueError, 'Shape must be at least rank 3'):
+ self._adjust_saturation_in_yiq_tf(x_np, scale)
+ x_np = np.random.rand(4, 2, 4) * 255.
+ scale = np.random.rand() * 2.0 - 1.0
+ with self.assertRaisesOpError('input must have 3 channels but instead has '
+ '4 channels'):
+ self._adjust_saturation_in_yiq_tf(x_np, scale)
+
+
+class AdjustHueInYiqBenchmark(test.Benchmark):
+
+ def _benchmark_adjust_hue_in_yiq(self, device, cpu_count):
+ image_shape = [299, 299, 3]
+ warmup_rounds = 100
+ benchmark_rounds = 1000
+ config = config_pb2.ConfigProto()
+ if cpu_count is not None:
+ config.inter_op_parallelism_threads = 1
+ config.intra_op_parallelism_threads = cpu_count
+ with session.Session('', graph=ops.Graph(), config=config) as sess:
+ with ops.device(device):
+ inputs = variables.Variable(
+ random_ops.random_uniform(image_shape, dtype=dtypes.float32) * 255,
+ trainable=False,
+ dtype=dtypes.float32)
+ delta = constant_op.constant(0.1, dtype=dtypes.float32)
+ outputs = distort_image_ops.adjust_hsv_in_yiq(inputs, delta, 1, 1)
+ run_op = control_flow_ops.group(outputs)
+ sess.run(variables.global_variables_initializer())
+ for i in xrange(warmup_rounds + benchmark_rounds):
+ if i == warmup_rounds:
+ start = time.time()
+ sess.run(run_op)
+ end = time.time()
+ step_time = (end - start) / benchmark_rounds
+ tag = device + '_%s' % (cpu_count if cpu_count is not None else 'all')
+ print('benchmarkadjust_hue_in_yiq_299_299_3_%s step_time: %.2f us' %
+ (tag, step_time * 1e6))
+ self.report_benchmark(
+ name='benchmarkadjust_hue_in_yiq_299_299_3_%s' % (tag),
+ iters=benchmark_rounds,
+ wall_time=step_time)
+
+ def benchmark_adjust_hue_in_yiqCpu1(self):
+ self._benchmark_adjust_hue_in_yiq('/cpu:0', 1)
+
+ def benchmark_adjust_hue_in_yiqCpuAll(self):
+ self._benchmark_adjust_hue_in_yiq('/cpu:0', None)
+
+
+class AdjustSaturationInYiqBenchmark(test.Benchmark):
+
+ def _benchmark_adjust_saturation_in_yiq(self, device, cpu_count):
+ image_shape = [299, 299, 3]
+ warmup_rounds = 100
+ benchmark_rounds = 1000
+ config = config_pb2.ConfigProto()
+ if cpu_count is not None:
+ config.inter_op_parallelism_threads = 1
+ config.intra_op_parallelism_threads = cpu_count
+ with session.Session('', graph=ops.Graph(), config=config) as sess:
+ with ops.device(device):
+ inputs = variables.Variable(
+ random_ops.random_uniform(image_shape, dtype=dtypes.float32) * 255,
+ trainable=False,
+ dtype=dtypes.float32)
+ scale = constant_op.constant(0.1, dtype=dtypes.float32)
+ outputs = distort_image_ops.adjust_hsv_in_yiq(inputs, 0, scale, 1)
+ run_op = control_flow_ops.group(outputs)
+ sess.run(variables.global_variables_initializer())
+ for _ in xrange(warmup_rounds):
+ sess.run(run_op)
+ start = time.time()
+ for _ in xrange(benchmark_rounds):
+ sess.run(run_op)
+ end = time.time()
+ step_time = (end - start) / benchmark_rounds
+ tag = '%s' % (cpu_count) if cpu_count is not None else '_all'
+ print('benchmarkAdjustSaturationInYiq_299_299_3_cpu%s step_time: %.2f us' %
+ (tag, step_time * 1e6))
+ self.report_benchmark(
+ name='benchmarkAdjustSaturationInYiq_299_299_3_cpu%s' % (tag),
+ iters=benchmark_rounds,
+ wall_time=step_time)
+
+ def benchmark_adjust_saturation_in_yiq_cpu1(self):
+ self._benchmark_adjust_saturation_in_yiq('/cpu:0', 1)
+
+ def benchmark_adjust_saturation_in_yiq_cpu_all(self):
+ self._benchmark_adjust_saturation_in_yiq('/cpu:0', None)
+
+
+if __name__ == '__main__':
+ googletest.main()
diff --git a/tensorflow/contrib/image/python/ops/distort_image_ops.py b/tensorflow/contrib/image/python/ops/distort_image_ops.py
new file mode 100644
index 0000000000..39f023a2b4
--- /dev/null
+++ b/tensorflow/contrib/image/python/ops/distort_image_ops.py
@@ -0,0 +1,138 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Python layer for distort_image_ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.util import loader
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import image_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.platform import resource_loader
+
+_distort_image_ops = loader.load_op_library(
+ resource_loader.get_path_to_datafile('_distort_image_ops.so'))
+
+
+# pylint: disable=invalid-name
+def random_hsv_in_yiq(image,
+ max_delta_hue=0,
+ lower_saturation=1,
+ upper_saturation=1,
+ lower_value=1,
+ upper_value=1,
+ seed=None):
+ """Adjust hue, saturation, value of an RGB image randomly in YIQ color space.
+
+ Equivalent to `adjust_yiq_hsv()` but uses a `delta_h` randomly
+ picked in the interval `[-max_delta_hue, max_delta_hue]`, a `scale_saturation`
+ randomly picked in the interval `[lower_saturation, upper_saturation]`, and
+ a `scale_value` randomly picked in the interval
+ `[lower_saturation, upper_saturation]`.
+
+ Args:
+ image: RGB image or images. Size of the last dimension must be 3.
+ max_delta_hue: float. Maximum value for the random delta_hue. Passing 0
+ disables adjusting hue.
+ lower_saturation: float. Lower bound for the random scale_saturation.
+ upper_saturation: float. Upper bound for the random scale_saturation.
+ lower_value: float. Lower bound for the random scale_value.
+ upper_value: float. Upper bound for the random scale_value.
+ seed: An operation-specific seed. It will be used in conjunction
+ with the graph-level seed to determine the real seeds that will be
+ used in this operation. Please see the documentation of
+ set_random_seed for its interaction with the graph-level random seed.
+
+ Returns:
+ 3-D float tensor of shape `[height, width, channels]`.
+
+ Raises:
+ ValueError: if `max_delta`, `lower_saturation`, `upper_saturation`,
+ `lower_value`, or `upper_Value` is invalid.
+ """
+ if max_delta_hue < 0:
+ raise ValueError('max_delta must be non-negative.')
+
+ if lower_saturation < 0:
+ raise ValueError('lower_saturation must be non-negative.')
+
+ if lower_value < 0:
+ raise ValueError('lower_value must be non-negative.')
+
+ if lower_saturation > upper_saturation:
+ raise ValueError('lower_saturation must be < upper_saturation.')
+
+ if lower_value > upper_value:
+ raise ValueError('lower_value must be < upper_value.')
+
+ if max_delta_hue == 0:
+ delta_hue = 0
+ else:
+ delta_hue = random_ops.random_uniform(
+ [], -max_delta_hue, max_delta_hue, seed=seed)
+ if lower_saturation == upper_saturation:
+ scale_saturation = lower_saturation
+ else:
+ scale_saturation = random_ops.random_uniform(
+ [], lower_saturation, upper_saturation, seed=seed)
+ if lower_value == upper_value:
+ scale_value = lower_value
+ else:
+ scale_value = random_ops.random_uniform(
+ [], lower_value, upper_value, seed=seed)
+ return adjust_hsv_in_yiq(image, delta_hue, scale_saturation, scale_value)
+
+
+def adjust_hsv_in_yiq(image,
+ delta_hue=0,
+ scale_saturation=1,
+ scale_value=1,
+ name=None):
+ """Adjust hue, saturation, value of an RGB image in YIQ color space.
+
+ This is a convenience method that converts an RGB image to float
+ representation, converts it to YIQ, rotates the color around the Y channel by
+ delta_hue in radians, scales the chrominance channels (I, Q) by
+ scale_saturation, scales all channels (Y, I, Q) by scale_value,
+ converts back to RGB, and then back to the original data type.
+
+ `image` is an RGB image. The image hue is adjusted by converting the
+ image to YIQ, rotating around the luminance channel (Y) by
+ `delta_hue` in radians, multiplying the chrominance channels (I, Q) by
+ `scale_saturation`, and multiplying all channels (Y, I, Q) by
+ `scale_value`. The image is then converted back to RGB.
+
+ Args:
+ image: RGB image or images. Size of the last dimension must be 3.
+ delta_hue: float, the hue rotation amount, in radians.
+ scale_saturation: float, factor to multiply the saturation by.
+ scale_value: float, factor to multiply the value by.
+ name: A name for this operation (optional).
+
+ Returns:
+ Adjusted image(s), same shape and DType as `image`.
+ """
+ with ops.name_scope(name, 'adjust_hsv_in_yiq', [image]) as name:
+ image = ops.convert_to_tensor(image, name='image')
+ # Remember original dtype to so we can convert back if needed
+ orig_dtype = image.dtype
+ flt_image = image_ops.convert_image_dtype(image, dtypes.float32)
+
+ rgb_altered = _distort_image_ops.adjust_hsv_in_yiq(
+ flt_image, delta_hue, scale_saturation, scale_value)
+
+ return image_ops.convert_image_dtype(rgb_altered, orig_dtype)