diff options
Diffstat (limited to 'tensorflow/python/ops/image_ops.py')
-rw-r--r-- | tensorflow/python/ops/image_ops.py | 34 |
1 files changed, 29 insertions, 5 deletions
diff --git a/tensorflow/python/ops/image_ops.py b/tensorflow/python/ops/image_ops.py index b65ecf9aeb..2eeef95d99 100644 --- a/tensorflow/python/ops/image_ops.py +++ b/tensorflow/python/ops/image_ops.py @@ -563,30 +563,53 @@ def resize_images(images, _, height, width, depth = _ImageDimensions(images) - if width == new_width and height == new_height: + # Handle tensor-valued sizes as well as Python integers. + try: + new_width = ops.convert_to_tensor(new_width, dtypes.int32, + name='new_width') + new_width.get_shape().assert_has_rank(0) + except (TypeError, ValueError): + raise ValueError('new_width must be a scalar integer') + try: + new_height = ops.convert_to_tensor(new_height, dtypes.int32, + name='new_height') + new_height.get_shape().assert_has_rank(0) + except (TypeError, ValueError): + raise ValueError('new_height must be a scalar integer') + + new_width_const = tensor_util.constant_value(new_width) + new_height_const = tensor_util.constant_value(new_height) + + if width == new_width_const and height == new_height_const: if not is_batch: images = array_ops.squeeze(images, squeeze_dims=[0]) return images + new_size = array_ops.pack([new_height, new_width]) + if method == ResizeMethod.BILINEAR: images = gen_image_ops.resize_bilinear(images, - [new_height, new_width], + new_size, align_corners=align_corners) elif method == ResizeMethod.NEAREST_NEIGHBOR: images = gen_image_ops.resize_nearest_neighbor(images, - [new_height, new_width], + new_size, align_corners=align_corners) elif method == ResizeMethod.BICUBIC: images = gen_image_ops.resize_bicubic(images, - [new_height, new_width], + new_size, align_corners=align_corners) elif method == ResizeMethod.AREA: images = gen_image_ops.resize_area(images, - [new_height, new_width], + new_size, align_corners=align_corners) else: raise ValueError('Resize method is not implemented.') + # NOTE(mrry): The shape functions for the resize ops cannot unpack + # the packed values in `new_size`, so set the shape here. + images.set_shape([None, new_height_const, new_width_const, None]) + if not is_batch: images = array_ops.squeeze(images, squeeze_dims=[0]) return images @@ -779,6 +802,7 @@ ops.RegisterShape('AdjustContrastv2')( def _ResizeShape(op): """Shape function for the resize_bilinear and resize_nearest_neighbor ops.""" input_shape = op.inputs[0].get_shape().with_rank(4) + unused_size_shape = op.inputs[1].get_shape().merge_with([2]) size = tensor_util.constant_value(op.inputs[1]) if size is not None: height = size[0] |