aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/image_ops_impl.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/image_ops_impl.py')
-rw-r--r--tensorflow/python/ops/image_ops_impl.py138
1 files changed, 107 insertions, 31 deletions
diff --git a/tensorflow/python/ops/image_ops_impl.py b/tensorflow/python/ops/image_ops_impl.py
index 62072e1279..0a2d4e4792 100644
--- a/tensorflow/python/ops/image_ops_impl.py
+++ b/tensorflow/python/ops/image_ops_impl.py
@@ -90,22 +90,23 @@ def _is_tensor(x):
return isinstance(x, (ops.Tensor, variables.Variable))
-def _ImageDimensions(image):
+def _ImageDimensions(image, rank):
"""Returns the dimensions of an image tensor.
Args:
- image: A 3-D Tensor of shape `[height, width, channels]`.
+ image: A rank-D Tensor, e.g. 3-D of shape `[height, width, channels]`.
+ rank: The expected rank of the image.
Returns:
- A list of `[height, width, channels]` corresponding to the dimensions of the
+ A list corresponding to the dimensions of the
input image. Dimensions that are statically known are python integers,
otherwise they are integer scalar tensors.
"""
if image.get_shape().is_fully_defined():
return image.get_shape().as_list()
else:
- static_shape = image.get_shape().with_rank(3).as_list()
- dynamic_shape = array_ops.unstack(array_ops.shape(image), 3)
+ static_shape = image.get_shape().with_rank(rank).as_list()
+ dynamic_shape = array_ops.unstack(array_ops.shape(image), rank)
return [s if s is not None else d
for s, d in zip(static_shape, dynamic_shape)]
@@ -144,22 +145,39 @@ def _Check3DImage(image, require_static=True):
return []
-def _CheckAtLeast3DImage(image):
+def _CheckAtLeast3DImage(image, require_static=True):
"""Assert that we are working with properly shaped image.
Args:
image: >= 3-D Tensor of size [*, height, width, depth]
+ require_static: If `True`, requires that all dimensions of `image` are
+ known and non-zero.
Raises:
ValueError: if image.shape is not a [>= 3] vector.
+
+ Returns:
+ An empty list, if `image` has fully defined dimensions. Otherwise, a list
+ containing an assert op is returned.
"""
- if not image.get_shape().is_fully_defined():
+ try:
+ if image.get_shape().ndims is None:
+ image_shape = image.get_shape().with_rank(3)
+ else:
+ image_shape = image.get_shape().with_rank_at_least(3)
+ except ValueError:
+ raise ValueError("'image' must be at least three-dimensional.")
+ if require_static and not image_shape.is_fully_defined():
raise ValueError('\'image\' must be fully defined.')
- if image.get_shape().ndims < 3:
- raise ValueError('\'image\' must be at least three-dimensional.')
- if not all(x > 0 for x in image.get_shape()):
+ if any(x == 0 for x in image_shape):
raise ValueError('all dims of \'image.shape\' must be > 0: %s' %
- image.get_shape())
+ image_shape)
+ if not image_shape.is_fully_defined():
+ return [check_ops.assert_positive(array_ops.shape(image),
+ ["all dims of 'image.shape' "
+ "must be > 0."])]
+ else:
+ return []
def fix_image_flip_shape(image, result):
@@ -397,14 +415,18 @@ def pad_to_bounding_box(image, offset_height, offset_width, target_height,
`target_height` by `target_width`.
Args:
- image: 3-D tensor with shape `[height, width, channels]`
+ image: 4-D Tensor of shape `[batch, height, width, channels]` or
+ 3-D Tensor of shape `[height, width, channels]`.
offset_height: Number of rows of zeros to add on top.
offset_width: Number of columns of zeros to add on the left.
target_height: Height of output image.
target_width: Width of output image.
Returns:
- 3-D tensor of shape `[target_height, target_width, channels]`
+ If `image` was 4-D, a 4-D float Tensor of shape
+ `[batch, target_height, target_width, channels]`.
+ If `image` was 3-D, a 3-D float Tensor of shape
+ `[target_height, target_width, channels]`.
Raises:
ValueError: If the shape of `image` is incompatible with the `offset_*` or
@@ -414,9 +436,22 @@ def pad_to_bounding_box(image, offset_height, offset_width, target_height,
image = ops.convert_to_tensor(image, name='image')
assert_ops = []
- assert_ops += _Check3DImage(image, require_static=False)
+ assert_ops += _CheckAtLeast3DImage(image, require_static=False)
+
+ is_batch = True
+ image_shape = image.get_shape()
+ if image_shape.ndims == 3:
+ is_batch = False
+ image = array_ops.expand_dims(image, 0)
+ elif image_shape.ndims is None:
+ is_batch = False
+ image = array_ops.expand_dims(image, 0)
+ image.set_shape([None] * 4)
+ elif image_shape.ndims != 4:
+ raise ValueError('\'image\' must have either 3 or 4 dimensions.')
+
+ batch, height, width, depth = _ImageDimensions(image, rank=4)
- height, width, depth = _ImageDimensions(image)
after_padding_width = target_width - offset_width - width
after_padding_height = target_height - offset_height - height
@@ -433,15 +468,18 @@ def pad_to_bounding_box(image, offset_height, offset_width, target_height,
# Do not pad on the depth dimensions.
paddings = array_ops.reshape(
array_ops.stack([
- offset_height, after_padding_height, offset_width,
+ 0, 0, offset_height, after_padding_height, offset_width,
after_padding_width, 0, 0
- ]), [3, 2])
+ ]), [4, 2])
padded = array_ops.pad(image, paddings)
padded_shape = [None if _is_tensor(i) else i
- for i in [target_height, target_width, depth]]
+ for i in [batch, target_height, target_width, depth]]
padded.set_shape(padded_shape)
+ if not is_batch:
+ padded = array_ops.squeeze(padded, squeeze_dims=[0])
+
return padded
@@ -455,7 +493,8 @@ def crop_to_bounding_box(image, offset_height, offset_width, target_height,
`offset_height + target_height, offset_width + target_width`.
Args:
- image: 3-D tensor with shape `[height, width, channels]`
+ image: 4-D Tensor of shape `[batch, height, width, channels]` or
+ 3-D Tensor of shape `[height, width, channels]`.
offset_height: Vertical coordinate of the top-left corner of the result in
the input.
offset_width: Horizontal coordinate of the top-left corner of the result in
@@ -464,7 +503,10 @@ def crop_to_bounding_box(image, offset_height, offset_width, target_height,
target_width: Width of the result.
Returns:
- 3-D tensor of image with shape `[target_height, target_width, channels]`
+ If `image` was 4-D, a 4-D float Tensor of shape
+ `[batch, target_height, target_width, channels]`.
+ If `image` was 3-D, a 3-D float Tensor of shape
+ `[target_height, target_width, channels]`.
Raises:
ValueError: If the shape of `image` is incompatible with the `offset_*` or
@@ -474,9 +516,21 @@ def crop_to_bounding_box(image, offset_height, offset_width, target_height,
image = ops.convert_to_tensor(image, name='image')
assert_ops = []
- assert_ops += _Check3DImage(image, require_static=False)
+ assert_ops += _CheckAtLeast3DImage(image, require_static=False)
+
+ is_batch = True
+ image_shape = image.get_shape()
+ if image_shape.ndims == 3:
+ is_batch = False
+ image = array_ops.expand_dims(image, 0)
+ elif image_shape.ndims is None:
+ is_batch = False
+ image = array_ops.expand_dims(image, 0)
+ image.set_shape([None] * 4)
+ elif image_shape.ndims != 4:
+ raise ValueError('\'image\' must have either 3 or 4 dimensions.')
- height, width, depth = _ImageDimensions(image)
+ batch, height, width, depth = _ImageDimensions(image, rank=4)
assert_ops += _assert(offset_width >= 0, ValueError,
'offset_width must be >= 0.')
@@ -493,13 +547,16 @@ def crop_to_bounding_box(image, offset_height, offset_width, target_height,
image = control_flow_ops.with_dependencies(assert_ops, image)
cropped = array_ops.slice(image,
- array_ops.stack([offset_height, offset_width, 0]),
- array_ops.stack([target_height, target_width, -1]))
+ array_ops.stack([0, offset_height, offset_width, 0]),
+ array_ops.stack([-1, target_height, target_width, -1]))
cropped_shape = [None if _is_tensor(i) else i
- for i in [target_height, target_width, depth]]
+ for i in [batch, target_height, target_width, depth]]
cropped.set_shape(cropped_shape)
+ if not is_batch:
+ cropped = array_ops.squeeze(cropped, squeeze_dims=[0])
+
return cropped
@@ -516,7 +573,8 @@ def resize_image_with_crop_or_pad(image, target_height, target_width):
dimension.
Args:
- image: 3-D tensor of shape `[height, width, channels]`
+ image: 4-D Tensor of shape `[batch, height, width, channels]` or
+ 3-D Tensor of shape `[height, width, channels]`.
target_height: Target height.
target_width: Target width.
@@ -524,13 +582,27 @@ def resize_image_with_crop_or_pad(image, target_height, target_width):
ValueError: if `target_height` or `target_width` are zero or negative.
Returns:
- Cropped and/or padded image of shape
- `[target_height, target_width, channels]`
+ Cropped and/or padded image.
+ If `image` was 4-D, a 4-D float Tensor of shape
+ `[batch, new_height, new_width, channels]`.
+ If `image` was 3-D, a 3-D float Tensor of shape
+ `[new_height, new_width, channels]`.
"""
image = ops.convert_to_tensor(image, name='image')
+ image_shape = image.get_shape()
+ is_batch = True
+ if image_shape.ndims == 3:
+ is_batch = False
+ image = array_ops.expand_dims(image, 0)
+ elif image_shape.ndims is None:
+ is_batch = False
+ image = array_ops.expand_dims(image, 0)
+ image.set_shape([None] * 4)
+ elif image_shape.ndims != 4:
+ raise ValueError('\'image\' must have either 3 or 4 dimensions.')
assert_ops = []
- assert_ops += _Check3DImage(image, require_static=False)
+ assert_ops += _CheckAtLeast3DImage(image, require_static=False)
assert_ops += _assert(target_width > 0, ValueError,
'target_width must be > 0.')
assert_ops += _assert(target_height > 0, ValueError,
@@ -563,7 +635,7 @@ def resize_image_with_crop_or_pad(image, target_height, target_width):
else:
return x == y
- height, width, _ = _ImageDimensions(image)
+ _, height, width, _ = _ImageDimensions(image, rank=4)
width_diff = target_width - width
offset_crop_width = max_(-width_diff // 2, 0)
offset_pad_width = max_(width_diff // 2, 0)
@@ -585,7 +657,7 @@ def resize_image_with_crop_or_pad(image, target_height, target_width):
if resized.get_shape().ndims is None:
raise ValueError('resized contains no shape.')
- resized_height, resized_width, _ = _ImageDimensions(resized)
+ _, resized_height, resized_width, _ = _ImageDimensions(resized, rank=4)
assert_ops = []
assert_ops += _assert(equal_(resized_height, target_height), ValueError,
@@ -594,6 +666,10 @@ def resize_image_with_crop_or_pad(image, target_height, target_width):
'resized width is not correct.')
resized = control_flow_ops.with_dependencies(assert_ops, resized)
+
+ if not is_batch:
+ resized = array_ops.squeeze(resized, squeeze_dims=[0])
+
return resized