diff options
Diffstat (limited to 'tensorflow/python/ops/image_ops_impl.py')
-rw-r--r-- | tensorflow/python/ops/image_ops_impl.py | 74 |
1 files changed, 49 insertions, 25 deletions
diff --git a/tensorflow/python/ops/image_ops_impl.py b/tensorflow/python/ops/image_ops_impl.py index bdcf420980..f27d9224c1 100644 --- a/tensorflow/python/ops/image_ops_impl.py +++ b/tensorflow/python/ops/image_ops_impl.py @@ -28,6 +28,7 @@ from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import functional_ops from tensorflow.python.ops import gen_image_ops from tensorflow.python.ops import gen_nn_ops from tensorflow.python.ops import math_ops @@ -258,14 +259,14 @@ def random_flip_up_down(image, seed=None): dimension, which is `height`. Otherwise output the image as-is. Args: - image: A 3-D tensor of shape `[height, width, channels].` + image: 4-D Tensor of shape `[batch, height, width, channels]` or + 3-D Tensor of shape `[height, width, channels]`. seed: A Python integer. Used to create a random seed. See @{tf.set_random_seed} for behavior. Returns: - A 3-D tensor of the same type and shape as `image`. - + A tensor of the same type and shape as `image`. Raises: ValueError: if the shape of `image` not supported. """ @@ -280,13 +281,14 @@ def random_flip_left_right(image, seed=None): second dimension, which is `width`. Otherwise output the image as-is. Args: - image: A 3-D tensor of shape `[height, width, channels].` + image: 4-D Tensor of shape `[batch, height, width, channels]` or + 3-D Tensor of shape `[height, width, channels]`. seed: A Python integer. Used to create a random seed. See @{tf.set_random_seed} for behavior. Returns: - A 3-D tensor of the same type and shape as `image`. + A tensor of the same type and shape as `image`. Raises: ValueError: if the shape of `image` not supported. @@ -297,7 +299,8 @@ def random_flip_left_right(image, seed=None): def _random_flip(image, flip_index, seed, scope_name): """Randomly (50% chance) flip an image along axis `flip_index`. Args: - image: A 3-D tensor of shape `[height, width, channels].` + image: 4-D Tensor of shape `[batch, height, width, channels]` or + 3-D Tensor of shape `[height, width, channels]`. flip_index: The dimension along which to flip the image. Vertical: 0, Horizontal: 1 seed: A Python integer. Used to create a random seed. See @@ -306,22 +309,37 @@ def _random_flip(image, flip_index, seed, scope_name): scope_name: Name of the scope in which the ops are added. Returns: - A 3-D tensor of the same type and shape as `image`. + A tensor of the same type and shape as `image`. Raises: ValueError: if the shape of `image` not supported. """ with ops.name_scope(None, scope_name, [image]) as scope: image = ops.convert_to_tensor(image, name='image') - image = _Assert3DImage(image) - uniform_random = random_ops.random_uniform([], 0, 1.0, seed=seed) - mirror_cond = math_ops.less(uniform_random, .5) - result = control_flow_ops.cond( - mirror_cond, - lambda: array_ops.reverse(image, [flip_index]), - lambda: image, - name=scope) - return fix_image_flip_shape(image, result) + image = _AssertAtLeast3DImage(image) + shape = image.get_shape() + if shape.ndims == 3 or shape.ndims is None: + uniform_random = random_ops.random_uniform([], 0, 1.0, seed=seed) + mirror_cond = math_ops.less(uniform_random, .5) + result = control_flow_ops.cond( + mirror_cond, + lambda: array_ops.reverse(image, [flip_index]), + lambda: image, + name=scope + ) + return fix_image_flip_shape(image, result) + elif shape.ndims == 4: + uniform_random = random_ops.random_uniform( + [array_ops.shape(image)[0]], 0, 1.0, seed=seed + ) + mirror_cond = math_ops.less(uniform_random, .5) + return array_ops.where( + mirror_cond, + image, + functional_ops.map_fn(lambda x: array_ops.reverse(x, [flip_index]), image, dtype=image.dtype) + ) + else: + raise ValueError('\'image\' must have either 3 or 4 dimensions.') @tf_export('image.flip_left_right') @@ -1634,13 +1652,13 @@ def is_jpeg(contents, name=None): @tf_export('image.decode_image') -def decode_image(contents, channels=None, name=None): +def decode_image(contents, channels=None, dtype=dtypes.uint8, name=None): """Convenience function for `decode_bmp`, `decode_gif`, `decode_jpeg`, and `decode_png`. Detects whether an image is a BMP, GIF, JPEG, or PNG, and performs the - appropriate operation to convert the input bytes `string` into a `Tensor` of - type `uint8`. + appropriate operation to convert the input bytes `string` into a `Tensor` + of type `dtype`. Note: `decode_gif` returns a 4-D array `[num_frames, height, width, 3]`, as opposed to `decode_bmp`, `decode_jpeg` and `decode_png`, which return 3-D @@ -1652,10 +1670,11 @@ def decode_image(contents, channels=None, name=None): contents: 0-D `string`. The encoded image bytes. channels: An optional `int`. Defaults to `0`. Number of color channels for the decoded image. + dtype: The desired DType of the returned `Tensor`. name: A name for the operation (optional) Returns: - `Tensor` with type `uint8` with shape `[height, width, num_channels]` for + `Tensor` with type `dtype` and shape `[height, width, num_channels]` for BMP, JPEG, and PNG images and shape `[num_frames, height, width, 3]` for GIF images. @@ -1679,7 +1698,7 @@ def decode_image(contents, channels=None, name=None): channels_msg = 'Channels must be in (None, 0, 3) when decoding BMP images' assert_channels = control_flow_ops.Assert(good_channels, [channels_msg]) with ops.control_dependencies([assert_decode, assert_channels]): - return gen_image_ops.decode_bmp(contents) + return convert_image_dtype(gen_image_ops.decode_bmp(contents), dtype) def _gif(): # Create assert to make sure that channels is not set to 1 @@ -1692,7 +1711,7 @@ def decode_image(contents, channels=None, name=None): channels_msg = 'Channels must be in (None, 0, 3) when decoding GIF images' assert_channels = control_flow_ops.Assert(good_channels, [channels_msg]) with ops.control_dependencies([assert_channels]): - return gen_image_ops.decode_gif(contents) + return convert_image_dtype(gen_image_ops.decode_gif(contents), dtype) def check_gif(): # Create assert op to check that bytes are GIF decodable @@ -1701,7 +1720,11 @@ def decode_image(contents, channels=None, name=None): def _png(): """Decodes a PNG image.""" - return gen_image_ops.decode_png(contents, channels) + return convert_image_dtype( + gen_image_ops.decode_png(contents, channels, + dtype=dtypes.uint8 + if dtype == dtypes.uint8 + else dtypes.uint16), dtype) def check_png(): """Checks if an image is PNG.""" @@ -1717,7 +1740,8 @@ def decode_image(contents, channels=None, name=None): 'images') assert_channels = control_flow_ops.Assert(good_channels, [channels_msg]) with ops.control_dependencies([assert_channels]): - return gen_image_ops.decode_jpeg(contents, channels) + return convert_image_dtype( + gen_image_ops.decode_jpeg(contents, channels), dtype) # Decode normal JPEG images (start with \xff\xd8\xff\xe0) # as well as JPEG images with EXIF data (start with \xff\xd8\xff\xe1). @@ -1878,7 +1902,7 @@ def sample_distorted_bounding_box(image_size, width / height within this range. area_range: An optional list of `floats`. Defaults to `[0.05, 1]`. The cropped area of the image must contain a fraction of the - supplied image within in this range. + supplied image within this range. max_attempts: An optional `int`. Defaults to `100`. Number of attempts at generating a cropped region of the image of the specified constraints. After `max_attempts` failures, return the |