diff options
Diffstat (limited to 'tensorflow/contrib/image/python/ops/image_ops.py')
-rw-r--r-- | tensorflow/contrib/image/python/ops/image_ops.py | 39 |
1 files changed, 17 insertions, 22 deletions
diff --git a/tensorflow/contrib/image/python/ops/image_ops.py b/tensorflow/contrib/image/python/ops/image_ops.py index d3c114a88d..cd984c8054 100644 --- a/tensorflow/contrib/image/python/ops/image_ops.py +++ b/tensorflow/contrib/image/python/ops/image_ops.py @@ -212,11 +212,7 @@ def translations_to_projective_transforms(translations, name=None): axis=1) -def transform(images, - transforms, - interpolation="NEAREST", - output_shape=None, - name=None): +def transform(images, transforms, interpolation="NEAREST", name=None): """Applies the given transform(s) to the image(s). Args: @@ -233,10 +229,6 @@ def transform(images, the transform mapping input points to output points. Note that gradients are not backpropagated into transformation parameters. interpolation: Interpolation mode. Supported values: "NEAREST", "BILINEAR". - output_shape: Output dimesion after the transform, [height, width]. - If None, output is the same size as input image. - - name: The name of the op. Returns: Image(s) with the same type and shape as `images`, with the given @@ -263,13 +255,6 @@ def transform(images, else: raise TypeError("Images should have rank between 2 and 4.") - if output_shape is None: - output_shape = array_ops.shape(images)[1:3] - elif len(output_shape) != 2: - raise TypeError( - "output_shape must either be None or a vector of 2 elements. %s" % - str(output_shape)) - if len(transform_or_transforms.get_shape()) == 1: transforms = transform_or_transforms[None] elif transform_or_transforms.get_shape().ndims is None: @@ -280,7 +265,7 @@ def transform(images, else: raise TypeError("Transforms should have rank 1 or 2.") output = gen_image_ops.image_projective_transform( - images, transforms, output_shape, interpolation=interpolation.upper()) + images, transforms, interpolation=interpolation.upper()) if len(image_or_images.get_shape()) == 2: return output[0, :, :, 0] elif len(image_or_images.get_shape()) == 3: @@ -390,6 +375,14 @@ def _image_projective_transform_grad(op, grad): if image_or_images.dtype.base_dtype not in _IMAGE_DTYPES: raise TypeError("Invalid dtype %s." % image_or_images.dtype) + if len(image_or_images.get_shape()) == 2: + images = image_or_images[None, :, :, None] + elif len(image_or_images.get_shape()) == 3: + images = image_or_images[None, :, :, :] + elif len(image_or_images.get_shape()) == 4: + images = image_or_images + else: + raise TypeError("Images should have rank between 2 and 4") if len(transform_or_transforms.get_shape()) == 1: transforms = transform_or_transforms[None] elif len(transform_or_transforms.get_shape()) == 2: @@ -402,11 +395,13 @@ def _image_projective_transform_grad(op, grad): inverse = linalg_ops.matrix_inverse(transforms) transforms = matrices_to_flat_transforms(inverse) output = gen_image_ops.image_projective_transform( - images=grad, - transforms=transforms, - output_shape=array_ops.shape(image_or_images)[1:3], - interpolation=interpolation) - return [output, None, None] + grad, transforms, interpolation=interpolation) + if len(image_or_images.get_shape()) == 2: + return [output[0, :, :, 0], None] + elif len(image_or_images.get_shape()) == 3: + return [output[0, :, :, :], None] + else: + return [output, None] def bipartite_match(distance_mat, |