aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Derek Murray <mrry@google.com>2016-02-05 09:41:06 -0800
committerGravatar Vijay Vasudevan <vrv@google.com>2016-02-05 17:04:22 -0800
commit3f06730be4c21d158b9714841098048bd48d89c6 (patch)
treeeaf20b4bcb683a9dd0ba95ef3a01afcba9b07ded
parent434b79ec286c13bc0fb8b65ff824cf7ea3ee0c3e (diff)
Enables `tf.image.resize_images()` to accept tensors for height, width.
Fixes #1001. Change: 113956896
m---------google/protobuf0
-rw-r--r--tensorflow/python/ops/image_ops.py34
-rw-r--r--tensorflow/python/ops/image_ops_test.py52
3 files changed, 81 insertions, 5 deletions
diff --git a/google/protobuf b/google/protobuf
-Subproject bd8a476510d17d3841ff2509fbd67b7f4b543c1
+Subproject 0906f5d18a2548024b511eadcbb4cfc0ca56cd6
diff --git a/tensorflow/python/ops/image_ops.py b/tensorflow/python/ops/image_ops.py
index b65ecf9aeb..2eeef95d99 100644
--- a/tensorflow/python/ops/image_ops.py
+++ b/tensorflow/python/ops/image_ops.py
@@ -563,30 +563,53 @@ def resize_images(images,
_, height, width, depth = _ImageDimensions(images)
- if width == new_width and height == new_height:
+ # Handle tensor-valued sizes as well as Python integers.
+ try:
+ new_width = ops.convert_to_tensor(new_width, dtypes.int32,
+ name='new_width')
+ new_width.get_shape().assert_has_rank(0)
+ except (TypeError, ValueError):
+ raise ValueError('new_width must be a scalar integer')
+ try:
+ new_height = ops.convert_to_tensor(new_height, dtypes.int32,
+ name='new_height')
+ new_height.get_shape().assert_has_rank(0)
+ except (TypeError, ValueError):
+ raise ValueError('new_height must be a scalar integer')
+
+ new_width_const = tensor_util.constant_value(new_width)
+ new_height_const = tensor_util.constant_value(new_height)
+
+ if width == new_width_const and height == new_height_const:
if not is_batch:
images = array_ops.squeeze(images, squeeze_dims=[0])
return images
+ new_size = array_ops.pack([new_height, new_width])
+
if method == ResizeMethod.BILINEAR:
images = gen_image_ops.resize_bilinear(images,
- [new_height, new_width],
+ new_size,
align_corners=align_corners)
elif method == ResizeMethod.NEAREST_NEIGHBOR:
images = gen_image_ops.resize_nearest_neighbor(images,
- [new_height, new_width],
+ new_size,
align_corners=align_corners)
elif method == ResizeMethod.BICUBIC:
images = gen_image_ops.resize_bicubic(images,
- [new_height, new_width],
+ new_size,
align_corners=align_corners)
elif method == ResizeMethod.AREA:
images = gen_image_ops.resize_area(images,
- [new_height, new_width],
+ new_size,
align_corners=align_corners)
else:
raise ValueError('Resize method is not implemented.')
+ # NOTE(mrry): The shape functions for the resize ops cannot unpack
+ # the packed values in `new_size`, so set the shape here.
+ images.set_shape([None, new_height_const, new_width_const, None])
+
if not is_batch:
images = array_ops.squeeze(images, squeeze_dims=[0])
return images
@@ -779,6 +802,7 @@ ops.RegisterShape('AdjustContrastv2')(
def _ResizeShape(op):
"""Shape function for the resize_bilinear and resize_nearest_neighbor ops."""
input_shape = op.inputs[0].get_shape().with_rank(4)
+ unused_size_shape = op.inputs[1].get_shape().merge_with([2])
size = tensor_util.constant_value(op.inputs[1])
if size is not None:
height = size[0]
diff --git a/tensorflow/python/ops/image_ops_test.py b/tensorflow/python/ops/image_ops_test.py
index f44fe344e6..d09004556c 100644
--- a/tensorflow/python/ops/image_ops_test.py
+++ b/tensorflow/python/ops/image_ops_test.py
@@ -654,6 +654,58 @@ class ResizeImagesTest(test_util.TensorFlowTestCase):
newshape = yshape.eval()
self.assertAllEqual(single_shape, newshape)
+ def testTensorArguments(self):
+ img_shape = [1, 6, 4, 1]
+ single_shape = [6, 4, 1]
+ # This test is also conducted with int8, so 127 is the maximum
+ # value that can be used.
+ data = [127, 127, 64, 64,
+ 127, 127, 64, 64,
+ 64, 64, 127, 127,
+ 64, 64, 127, 127,
+ 50, 50, 100, 100,
+ 50, 50, 100, 100]
+ target_height = array_ops.placeholder(dtypes.int32)
+ target_width = array_ops.placeholder(dtypes.int32)
+
+ img_np = np.array(data, dtype=np.uint8).reshape(img_shape)
+
+ for opt in self.OPTIONS:
+ with self.test_session() as sess:
+ image = constant_op.constant(img_np, shape=img_shape)
+ y = image_ops.resize_images(image, target_height, target_width, opt)
+ yshape = array_ops.shape(y)
+ resized, newshape = sess.run([y, yshape], {target_height: 6,
+ target_width: 4})
+ self.assertAllEqual(img_shape, newshape)
+ self.assertAllClose(resized, img_np, atol=1e-5)
+
+ # Resizing with a single image must leave the shape unchanged also.
+ with self.test_session():
+ img_single = img_np.reshape(single_shape)
+ image = constant_op.constant(img_single, shape=single_shape)
+ y = image_ops.resize_images(image, target_height, target_width,
+ self.OPTIONS[0])
+ yshape = array_ops.shape(y)
+ newshape = yshape.eval(feed_dict={target_height: 6, target_width: 4})
+ self.assertAllEqual(single_shape, newshape)
+
+ # Incorrect shape.
+ with self.assertRaises(ValueError):
+ _ = image_ops.resize_images(
+ image, [12, 32], 4, image_ops.ResizeMethod.BILINEAR)
+ with self.assertRaises(ValueError):
+ _ = image_ops.resize_images(
+ image, 6, [12, 32], image_ops.ResizeMethod.BILINEAR)
+
+ # Incorrect dtypes.
+ with self.assertRaises(ValueError):
+ _ = image_ops.resize_images(
+ image, 6.0, 4, image_ops.ResizeMethod.BILINEAR)
+ with self.assertRaises(ValueError):
+ _ = image_ops.resize_images(
+ image, 6, 4.0, image_ops.ResizeMethod.BILINEAR)
+
def testResizeDown(self):
# This test is also conducted with int8, so 127 is the maximum
# value that can be used.