diff options
Diffstat (limited to 'tensorflow/python/ops/image_ops_impl.py')
-rw-r--r-- | tensorflow/python/ops/image_ops_impl.py | 138 |
1 files changed, 107 insertions, 31 deletions
diff --git a/tensorflow/python/ops/image_ops_impl.py b/tensorflow/python/ops/image_ops_impl.py index 62072e1279..0a2d4e4792 100644 --- a/tensorflow/python/ops/image_ops_impl.py +++ b/tensorflow/python/ops/image_ops_impl.py @@ -90,22 +90,23 @@ def _is_tensor(x): return isinstance(x, (ops.Tensor, variables.Variable)) -def _ImageDimensions(image): +def _ImageDimensions(image, rank): """Returns the dimensions of an image tensor. Args: - image: A 3-D Tensor of shape `[height, width, channels]`. + image: A rank-D Tensor. For 3-D of shape: `[height, width, channels]`. + rank: The expected rank of the image Returns: - A list of `[height, width, channels]` corresponding to the dimensions of the + A list of corresponding to the dimensions of the input image. Dimensions that are statically known are python integers, otherwise they are integer scalar tensors. """ if image.get_shape().is_fully_defined(): return image.get_shape().as_list() else: - static_shape = image.get_shape().with_rank(3).as_list() - dynamic_shape = array_ops.unstack(array_ops.shape(image), 3) + static_shape = image.get_shape().with_rank(rank).as_list() + dynamic_shape = array_ops.unstack(array_ops.shape(image), rank) return [s if s is not None else d for s, d in zip(static_shape, dynamic_shape)] @@ -144,22 +145,39 @@ def _Check3DImage(image, require_static=True): return [] -def _CheckAtLeast3DImage(image): +def _CheckAtLeast3DImage(image, require_static=True): """Assert that we are working with properly shaped image. Args: image: >= 3-D Tensor of size [*, height, width, depth] + require_static: If `True`, requires that all dimensions of `image` are + known and non-zero. Raises: ValueError: if image.shape is not a [>= 3] vector. + + Returns: + An empty list, if `image` has fully defined dimensions. Otherwise, a list + containing an assert op is returned. """ - if not image.get_shape().is_fully_defined(): + try: + if image.get_shape().ndims is None: + image_shape = image.get_shape().with_rank(3) + else: + image_shape = image.get_shape().with_rank_at_least(3) + except ValueError: + raise ValueError("'image' must be at least three-dimensional.") + if require_static and not image_shape.is_fully_defined(): raise ValueError('\'image\' must be fully defined.') - if image.get_shape().ndims < 3: - raise ValueError('\'image\' must be at least three-dimensional.') - if not all(x > 0 for x in image.get_shape()): + if any(x == 0 for x in image_shape): raise ValueError('all dims of \'image.shape\' must be > 0: %s' % - image.get_shape()) + image_shape) + if not image_shape.is_fully_defined(): + return [check_ops.assert_positive(array_ops.shape(image), + ["all dims of 'image.shape' " + "must be > 0."])] + else: + return [] def fix_image_flip_shape(image, result): @@ -397,14 +415,18 @@ def pad_to_bounding_box(image, offset_height, offset_width, target_height, `target_height` by `target_width`. Args: - image: 3-D tensor with shape `[height, width, channels]` + image: 4-D Tensor of shape `[batch, height, width, channels]` or + 3-D Tensor of shape `[height, width, channels]`. offset_height: Number of rows of zeros to add on top. offset_width: Number of columns of zeros to add on the left. target_height: Height of output image. target_width: Width of output image. Returns: - 3-D tensor of shape `[target_height, target_width, channels]` + If `image` was 4-D, a 4-D float Tensor of shape + `[batch, target_height, target_width, channels]` + If `image` was 3-D, a 3-D float Tensor of shape + `[target_height, target_width, channels]` Raises: ValueError: If the shape of `image` is incompatible with the `offset_*` or @@ -414,9 +436,22 @@ def pad_to_bounding_box(image, offset_height, offset_width, target_height, image = ops.convert_to_tensor(image, name='image') assert_ops = [] - assert_ops += _Check3DImage(image, require_static=False) + assert_ops += _CheckAtLeast3DImage(image, require_static=False) + + is_batch = True + image_shape = image.get_shape() + if image_shape.ndims == 3: + is_batch = False + image = array_ops.expand_dims(image, 0) + elif image_shape.ndims is None: + is_batch = False + image = array_ops.expand_dims(image, 0) + image.set_shape([None] * 4) + elif image_shape.ndims != 4: + raise ValueError('\'image\' must have either 3 or 4 dimensions.') + + batch, height, width, depth = _ImageDimensions(image, rank=4) - height, width, depth = _ImageDimensions(image) after_padding_width = target_width - offset_width - width after_padding_height = target_height - offset_height - height @@ -433,15 +468,18 @@ def pad_to_bounding_box(image, offset_height, offset_width, target_height, # Do not pad on the depth dimensions. paddings = array_ops.reshape( array_ops.stack([ - offset_height, after_padding_height, offset_width, + 0, 0, offset_height, after_padding_height, offset_width, after_padding_width, 0, 0 - ]), [3, 2]) + ]), [4, 2]) padded = array_ops.pad(image, paddings) padded_shape = [None if _is_tensor(i) else i - for i in [target_height, target_width, depth]] + for i in [batch, target_height, target_width, depth]] padded.set_shape(padded_shape) + if not is_batch: + padded = array_ops.squeeze(padded, squeeze_dims=[0]) + return padded @@ -455,7 +493,8 @@ def crop_to_bounding_box(image, offset_height, offset_width, target_height, `offset_height + target_height, offset_width + target_width`. Args: - image: 3-D tensor with shape `[height, width, channels]` + image: 4-D Tensor of shape `[batch, height, width, channels]` or + 3-D Tensor of shape `[height, width, channels]`. offset_height: Vertical coordinate of the top-left corner of the result in the input. offset_width: Horizontal coordinate of the top-left corner of the result in @@ -464,7 +503,10 @@ def crop_to_bounding_box(image, offset_height, offset_width, target_height, target_width: Width of the result. Returns: - 3-D tensor of image with shape `[target_height, target_width, channels]` + If `image` was 4-D, a 4-D float Tensor of shape + `[batch, target_height, target_width, channels]` + If `image` was 3-D, a 3-D float Tensor of shape + `[target_height, target_width, channels]` Raises: ValueError: If the shape of `image` is incompatible with the `offset_*` or @@ -474,9 +516,21 @@ def crop_to_bounding_box(image, offset_height, offset_width, target_height, image = ops.convert_to_tensor(image, name='image') assert_ops = [] - assert_ops += _Check3DImage(image, require_static=False) + assert_ops += _CheckAtLeast3DImage(image, require_static=False) + + is_batch = True + image_shape = image.get_shape() + if image_shape.ndims == 3: + is_batch = False + image = array_ops.expand_dims(image, 0) + elif image_shape.ndims is None: + is_batch = False + image = array_ops.expand_dims(image, 0) + image.set_shape([None] * 4) + elif image_shape.ndims != 4: + raise ValueError('\'image\' must have either 3 or 4 dimensions.') - height, width, depth = _ImageDimensions(image) + batch, height, width, depth = _ImageDimensions(image, rank=4) assert_ops += _assert(offset_width >= 0, ValueError, 'offset_width must be >= 0.') @@ -493,13 +547,16 @@ def crop_to_bounding_box(image, offset_height, offset_width, target_height, image = control_flow_ops.with_dependencies(assert_ops, image) cropped = array_ops.slice(image, - array_ops.stack([offset_height, offset_width, 0]), - array_ops.stack([target_height, target_width, -1])) + array_ops.stack([0, offset_height, offset_width, 0]), + array_ops.stack([-1, target_height, target_width, -1])) cropped_shape = [None if _is_tensor(i) else i - for i in [target_height, target_width, depth]] + for i in [batch, target_height, target_width, depth]] cropped.set_shape(cropped_shape) + if not is_batch: + cropped = array_ops.squeeze(cropped, squeeze_dims=[0]) + return cropped @@ -516,7 +573,8 @@ def resize_image_with_crop_or_pad(image, target_height, target_width): dimension. Args: - image: 3-D tensor of shape `[height, width, channels]` + image: 4-D Tensor of shape `[batch, height, width, channels]` or + 3-D Tensor of shape `[height, width, channels]`. target_height: Target height. target_width: Target width. @@ -524,13 +582,27 @@ def resize_image_with_crop_or_pad(image, target_height, target_width): ValueError: if `target_height` or `target_width` are zero or negative. Returns: - Cropped and/or padded image of shape - `[target_height, target_width, channels]` + Cropped and/or padded image. + If `images` was 4-D, a 4-D float Tensor of shape + `[batch, new_height, new_width, channels]`. + If `images` was 3-D, a 3-D float Tensor of shape + `[new_height, new_width, channels]`. """ image = ops.convert_to_tensor(image, name='image') + image_shape = image.get_shape() + is_batch = True + if image_shape.ndims == 3: + is_batch = False + image = array_ops.expand_dims(image, 0) + elif image_shape.ndims is None: + is_batch = False + image = array_ops.expand_dims(image, 0) + image.set_shape([None] * 4) + elif image_shape.ndims != 4: + raise ValueError('\'image\' must have either 3 or 4 dimensions.') assert_ops = [] - assert_ops += _Check3DImage(image, require_static=False) + assert_ops += _CheckAtLeast3DImage(image, require_static=False) assert_ops += _assert(target_width > 0, ValueError, 'target_width must be > 0.') assert_ops += _assert(target_height > 0, ValueError, @@ -563,7 +635,7 @@ def resize_image_with_crop_or_pad(image, target_height, target_width): else: return x == y - height, width, _ = _ImageDimensions(image) + _, height, width, _ = _ImageDimensions(image, rank=4) width_diff = target_width - width offset_crop_width = max_(-width_diff // 2, 0) offset_pad_width = max_(width_diff // 2, 0) @@ -585,7 +657,7 @@ def resize_image_with_crop_or_pad(image, target_height, target_width): if resized.get_shape().ndims is None: raise ValueError('resized contains no shape.') - resized_height, resized_width, _ = _ImageDimensions(resized) + _, resized_height, resized_width, _ = _ImageDimensions(resized, rank=4) assert_ops = [] assert_ops += _assert(equal_(resized_height, target_height), ValueError, @@ -594,6 +666,10 @@ def resize_image_with_crop_or_pad(image, target_height, target_width): 'resized width is not correct.') resized = control_flow_ops.with_dependencies(assert_ops, resized) + + if not is_batch: + resized = array_ops.squeeze(resized, squeeze_dims=[0]) + return resized |