aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/image_ops_impl.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/image_ops_impl.py')
-rw-r--r--tensorflow/python/ops/image_ops_impl.py74
1 files changed, 49 insertions, 25 deletions
diff --git a/tensorflow/python/ops/image_ops_impl.py b/tensorflow/python/ops/image_ops_impl.py
index bdcf420980..f27d9224c1 100644
--- a/tensorflow/python/ops/image_ops_impl.py
+++ b/tensorflow/python/ops/image_ops_impl.py
@@ -28,6 +28,7 @@ from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gen_image_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import math_ops
@@ -258,14 +259,14 @@ def random_flip_up_down(image, seed=None):
dimension, which is `height`. Otherwise output the image as-is.
Args:
- image: A 3-D tensor of shape `[height, width, channels].`
+ image: 4-D Tensor of shape `[batch, height, width, channels]` or
+ 3-D Tensor of shape `[height, width, channels]`.
seed: A Python integer. Used to create a random seed. See
@{tf.set_random_seed}
for behavior.
Returns:
- A 3-D tensor of the same type and shape as `image`.
-
+ A tensor of the same type and shape as `image`.
Raises:
ValueError: if the shape of `image` not supported.
"""
@@ -280,13 +281,14 @@ def random_flip_left_right(image, seed=None):
second dimension, which is `width`. Otherwise output the image as-is.
Args:
- image: A 3-D tensor of shape `[height, width, channels].`
+ image: 4-D Tensor of shape `[batch, height, width, channels]` or
+ 3-D Tensor of shape `[height, width, channels]`.
seed: A Python integer. Used to create a random seed. See
@{tf.set_random_seed}
for behavior.
Returns:
- A 3-D tensor of the same type and shape as `image`.
+ A tensor of the same type and shape as `image`.
Raises:
ValueError: if the shape of `image` not supported.
@@ -297,7 +299,8 @@ def random_flip_left_right(image, seed=None):
def _random_flip(image, flip_index, seed, scope_name):
"""Randomly (50% chance) flip an image along axis `flip_index`.
Args:
- image: A 3-D tensor of shape `[height, width, channels].`
+ image: 4-D Tensor of shape `[batch, height, width, channels]` or
+ 3-D Tensor of shape `[height, width, channels]`.
flip_index: The dimension along which to flip the image.
Vertical: 0, Horizontal: 1
seed: A Python integer. Used to create a random seed. See
@@ -306,22 +309,37 @@ def _random_flip(image, flip_index, seed, scope_name):
scope_name: Name of the scope in which the ops are added.
Returns:
- A 3-D tensor of the same type and shape as `image`.
+ A tensor of the same type and shape as `image`.
Raises:
ValueError: if the shape of `image` not supported.
"""
with ops.name_scope(None, scope_name, [image]) as scope:
image = ops.convert_to_tensor(image, name='image')
- image = _Assert3DImage(image)
- uniform_random = random_ops.random_uniform([], 0, 1.0, seed=seed)
- mirror_cond = math_ops.less(uniform_random, .5)
- result = control_flow_ops.cond(
- mirror_cond,
- lambda: array_ops.reverse(image, [flip_index]),
- lambda: image,
- name=scope)
- return fix_image_flip_shape(image, result)
+ image = _AssertAtLeast3DImage(image)
+ shape = image.get_shape()
+ if shape.ndims == 3 or shape.ndims is None:
+ uniform_random = random_ops.random_uniform([], 0, 1.0, seed=seed)
+ mirror_cond = math_ops.less(uniform_random, .5)
+ result = control_flow_ops.cond(
+ mirror_cond,
+ lambda: array_ops.reverse(image, [flip_index]),
+ lambda: image,
+ name=scope
+ )
+ return fix_image_flip_shape(image, result)
+ elif shape.ndims == 4:
+ uniform_random = random_ops.random_uniform(
+ [array_ops.shape(image)[0]], 0, 1.0, seed=seed
+ )
+ mirror_cond = math_ops.less(uniform_random, .5)
+ return array_ops.where(
+ mirror_cond,
+ image,
+ functional_ops.map_fn(lambda x: array_ops.reverse(x, [flip_index]), image, dtype=image.dtype)
+ )
+ else:
+ raise ValueError('\'image\' must have either 3 or 4 dimensions.')
@tf_export('image.flip_left_right')
@@ -1634,13 +1652,13 @@ def is_jpeg(contents, name=None):
@tf_export('image.decode_image')
-def decode_image(contents, channels=None, name=None):
+def decode_image(contents, channels=None, dtype=dtypes.uint8, name=None):
"""Convenience function for `decode_bmp`, `decode_gif`, `decode_jpeg`,
and `decode_png`.
Detects whether an image is a BMP, GIF, JPEG, or PNG, and performs the
- appropriate operation to convert the input bytes `string` into a `Tensor` of
- type `uint8`.
+ appropriate operation to convert the input bytes `string` into a `Tensor`
+ of type `dtype`.
Note: `decode_gif` returns a 4-D array `[num_frames, height, width, 3]`, as
opposed to `decode_bmp`, `decode_jpeg` and `decode_png`, which return 3-D
@@ -1652,10 +1670,11 @@ def decode_image(contents, channels=None, name=None):
contents: 0-D `string`. The encoded image bytes.
channels: An optional `int`. Defaults to `0`. Number of color channels for
the decoded image.
+ dtype: The desired DType of the returned `Tensor`.
name: A name for the operation (optional)
Returns:
- `Tensor` with type `uint8` with shape `[height, width, num_channels]` for
+ `Tensor` with type `dtype` and shape `[height, width, num_channels]` for
BMP, JPEG, and PNG images and shape `[num_frames, height, width, 3]` for
GIF images.
@@ -1679,7 +1698,7 @@ def decode_image(contents, channels=None, name=None):
channels_msg = 'Channels must be in (None, 0, 3) when decoding BMP images'
assert_channels = control_flow_ops.Assert(good_channels, [channels_msg])
with ops.control_dependencies([assert_decode, assert_channels]):
- return gen_image_ops.decode_bmp(contents)
+ return convert_image_dtype(gen_image_ops.decode_bmp(contents), dtype)
def _gif():
# Create assert to make sure that channels is not set to 1
@@ -1692,7 +1711,7 @@ def decode_image(contents, channels=None, name=None):
channels_msg = 'Channels must be in (None, 0, 3) when decoding GIF images'
assert_channels = control_flow_ops.Assert(good_channels, [channels_msg])
with ops.control_dependencies([assert_channels]):
- return gen_image_ops.decode_gif(contents)
+ return convert_image_dtype(gen_image_ops.decode_gif(contents), dtype)
def check_gif():
# Create assert op to check that bytes are GIF decodable
@@ -1701,7 +1720,11 @@ def decode_image(contents, channels=None, name=None):
def _png():
"""Decodes a PNG image."""
- return gen_image_ops.decode_png(contents, channels)
+ return convert_image_dtype(
+ gen_image_ops.decode_png(contents, channels,
+ dtype=dtypes.uint8
+ if dtype == dtypes.uint8
+ else dtypes.uint16), dtype)
def check_png():
"""Checks if an image is PNG."""
@@ -1717,7 +1740,8 @@ def decode_image(contents, channels=None, name=None):
'images')
assert_channels = control_flow_ops.Assert(good_channels, [channels_msg])
with ops.control_dependencies([assert_channels]):
- return gen_image_ops.decode_jpeg(contents, channels)
+ return convert_image_dtype(
+ gen_image_ops.decode_jpeg(contents, channels), dtype)
# Decode normal JPEG images (start with \xff\xd8\xff\xe0)
# as well as JPEG images with EXIF data (start with \xff\xd8\xff\xe1).
@@ -1878,7 +1902,7 @@ def sample_distorted_bounding_box(image_size,
width / height within this range.
area_range: An optional list of `floats`. Defaults to `[0.05, 1]`.
The cropped area of the image must contain a fraction of the
- supplied image within in this range.
+ supplied image within this range.
max_attempts: An optional `int`. Defaults to `100`.
Number of attempts at generating a cropped region of the image
of the specified constraints. After `max_attempts` failures, return the