aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/image_ops.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/image_ops.py')
-rw-r--r--tensorflow/python/ops/image_ops.py34
1 files changed, 29 insertions, 5 deletions
diff --git a/tensorflow/python/ops/image_ops.py b/tensorflow/python/ops/image_ops.py
index b65ecf9aeb..2eeef95d99 100644
--- a/tensorflow/python/ops/image_ops.py
+++ b/tensorflow/python/ops/image_ops.py
@@ -563,30 +563,53 @@ def resize_images(images,
_, height, width, depth = _ImageDimensions(images)
- if width == new_width and height == new_height:
+ # Handle tensor-valued sizes as well as Python integers.
+ try:
+ new_width = ops.convert_to_tensor(new_width, dtypes.int32,
+ name='new_width')
+ new_width.get_shape().assert_has_rank(0)
+ except (TypeError, ValueError):
+ raise ValueError('new_width must be a scalar integer')
+ try:
+ new_height = ops.convert_to_tensor(new_height, dtypes.int32,
+ name='new_height')
+ new_height.get_shape().assert_has_rank(0)
+ except (TypeError, ValueError):
+ raise ValueError('new_height must be a scalar integer')
+
+ new_width_const = tensor_util.constant_value(new_width)
+ new_height_const = tensor_util.constant_value(new_height)
+
+ if width == new_width_const and height == new_height_const:
if not is_batch:
images = array_ops.squeeze(images, squeeze_dims=[0])
return images
+ new_size = array_ops.pack([new_height, new_width])
+
if method == ResizeMethod.BILINEAR:
images = gen_image_ops.resize_bilinear(images,
- [new_height, new_width],
+ new_size,
align_corners=align_corners)
elif method == ResizeMethod.NEAREST_NEIGHBOR:
images = gen_image_ops.resize_nearest_neighbor(images,
- [new_height, new_width],
+ new_size,
align_corners=align_corners)
elif method == ResizeMethod.BICUBIC:
images = gen_image_ops.resize_bicubic(images,
- [new_height, new_width],
+ new_size,
align_corners=align_corners)
elif method == ResizeMethod.AREA:
images = gen_image_ops.resize_area(images,
- [new_height, new_width],
+ new_size,
align_corners=align_corners)
else:
raise ValueError('Resize method is not implemented.')
+ # NOTE(mrry): The shape functions for the resize ops cannot unpack
+ # the packed values in `new_size`, so set the shape here.
+ images.set_shape([None, new_height_const, new_width_const, None])
+
if not is_batch:
images = array_ops.squeeze(images, squeeze_dims=[0])
return images
@@ -779,6 +802,7 @@ ops.RegisterShape('AdjustContrastv2')(
def _ResizeShape(op):
"""Shape function for the resize_bilinear and resize_nearest_neighbor ops."""
input_shape = op.inputs[0].get_shape().with_rank(4)
+ unused_size_shape = op.inputs[1].get_shape().merge_with([2])
size = tensor_util.constant_value(op.inputs[1])
if size is not None:
height = size[0]