aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/image/python/ops/image_ops.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/image/python/ops/image_ops.py')
-rw-r--r--tensorflow/contrib/image/python/ops/image_ops.py39
1 files changed, 17 insertions, 22 deletions
diff --git a/tensorflow/contrib/image/python/ops/image_ops.py b/tensorflow/contrib/image/python/ops/image_ops.py
index d3c114a88d..cd984c8054 100644
--- a/tensorflow/contrib/image/python/ops/image_ops.py
+++ b/tensorflow/contrib/image/python/ops/image_ops.py
@@ -212,11 +212,7 @@ def translations_to_projective_transforms(translations, name=None):
axis=1)
-def transform(images,
- transforms,
- interpolation="NEAREST",
- output_shape=None,
- name=None):
+def transform(images, transforms, interpolation="NEAREST", name=None):
"""Applies the given transform(s) to the image(s).
Args:
@@ -233,10 +229,6 @@ def transform(images,
the transform mapping input points to output points. Note that gradients
are not backpropagated into transformation parameters.
interpolation: Interpolation mode. Supported values: "NEAREST", "BILINEAR".
- output_shape: Output dimesion after the transform, [height, width].
- If None, output is the same size as input image.
-
- name: The name of the op.
Returns:
Image(s) with the same type and shape as `images`, with the given
@@ -263,13 +255,6 @@ def transform(images,
else:
raise TypeError("Images should have rank between 2 and 4.")
- if output_shape is None:
- output_shape = array_ops.shape(images)[1:3]
- elif len(output_shape) != 2:
- raise TypeError(
- "output_shape must either be None or a vector of 2 elements. %s" %
- str(output_shape))
-
if len(transform_or_transforms.get_shape()) == 1:
transforms = transform_or_transforms[None]
elif transform_or_transforms.get_shape().ndims is None:
@@ -280,7 +265,7 @@ def transform(images,
else:
raise TypeError("Transforms should have rank 1 or 2.")
output = gen_image_ops.image_projective_transform(
- images, transforms, output_shape, interpolation=interpolation.upper())
+ images, transforms, interpolation=interpolation.upper())
if len(image_or_images.get_shape()) == 2:
return output[0, :, :, 0]
elif len(image_or_images.get_shape()) == 3:
@@ -390,6 +375,14 @@ def _image_projective_transform_grad(op, grad):
if image_or_images.dtype.base_dtype not in _IMAGE_DTYPES:
raise TypeError("Invalid dtype %s." % image_or_images.dtype)
+ if len(image_or_images.get_shape()) == 2:
+ images = image_or_images[None, :, :, None]
+ elif len(image_or_images.get_shape()) == 3:
+ images = image_or_images[None, :, :, :]
+ elif len(image_or_images.get_shape()) == 4:
+ images = image_or_images
+ else:
+ raise TypeError("Images should have rank between 2 and 4")
if len(transform_or_transforms.get_shape()) == 1:
transforms = transform_or_transforms[None]
elif len(transform_or_transforms.get_shape()) == 2:
@@ -402,11 +395,13 @@ def _image_projective_transform_grad(op, grad):
inverse = linalg_ops.matrix_inverse(transforms)
transforms = matrices_to_flat_transforms(inverse)
output = gen_image_ops.image_projective_transform(
- images=grad,
- transforms=transforms,
- output_shape=array_ops.shape(image_or_images)[1:3],
- interpolation=interpolation)
- return [output, None, None]
+ grad, transforms, interpolation=interpolation)
+ if len(image_or_images.get_shape()) == 2:
+ return [output[0, :, :, 0], None]
+ elif len(image_or_images.get_shape()) == 3:
+ return [output[0, :, :, :], None]
+ else:
+ return [output, None]
def bipartite_match(distance_mat,