aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/image
diff options
context:
space:
mode:
authorGravatar Dan Ringwalt <ringwalt@google.com>2018-02-05 10:36:24 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-05 10:40:26 -0800
commita9034babfd0100f0b998e900d074dd3d839d18bc (patch)
tree8c9b62ef0717c4f8045154cfd31c892c5756ff78 /tensorflow/contrib/image
parent2fde0f219c98dca9e75c9e5f95157d7b0edce746 (diff)
Make flat_transforms_to_matrices and matrices_to_flat_transforms public (#781).
PiperOrigin-RevId: 184549704
Diffstat (limited to 'tensorflow/contrib/image')
-rw-r--r--tensorflow/contrib/image/python/ops/image_ops.py87
1 files changed, 66 insertions, 21 deletions
diff --git a/tensorflow/contrib/image/python/ops/image_ops.py b/tensorflow/contrib/image/python/ops/image_ops.py
index 6122ee5805..c139ae89d8 100644
--- a/tensorflow/contrib/image/python/ops/image_ops.py
+++ b/tensorflow/contrib/image/python/ops/image_ops.py
@@ -290,31 +290,76 @@ def compose_transforms(*transforms):
"""
assert transforms, "transforms cannot be empty"
with ops.name_scope("compose_transforms"):
- composed = _flat_transforms_to_matrices(transforms[0])
+ composed = flat_transforms_to_matrices(transforms[0])
for tr in transforms[1:]:
# Multiply batches of matrices.
- composed = math_ops.matmul(composed, _flat_transforms_to_matrices(tr))
- return _transform_matrices_to_flat(composed)
+ composed = math_ops.matmul(composed, flat_transforms_to_matrices(tr))
+ return matrices_to_flat_transforms(composed)
-def _flat_transforms_to_matrices(transforms):
- # Make the transform(s) 2D in case the input is a single transform.
- transforms = array_ops.reshape(transforms, constant_op.constant([-1, 8]))
- num_transforms = array_ops.shape(transforms)[0]
- # Add a column of ones for the implicit last entry in the matrix.
- return array_ops.reshape(
- array_ops.concat(
- [transforms, array_ops.ones([num_transforms, 1])], axis=1),
- constant_op.constant([-1, 3, 3]))
+def flat_transforms_to_matrices(transforms):
+ """Converts `tf.contrib.image` projective transforms to affine matrices.
+ Note that the output matrices map output coordinates to input coordinates. For
+ the forward transformation matrix, call `tf.linalg.inv` on the result.
-def _transform_matrices_to_flat(transform_matrices):
- # Flatten each matrix.
- transforms = array_ops.reshape(transform_matrices,
- constant_op.constant([-1, 9]))
- # Divide each matrix by the last entry (normally 1).
- transforms /= transforms[:, 8:9]
- return transforms[:, :8]
+ Args:
+ transforms: Vector of length 8, or batches of transforms with shape
+ `(N, 8)`.
+
+ Returns:
+ 3D tensor of matrices with shape `(N, 3, 3)`. The output matrices map the
+ *output coordinates* (in homogeneous coordinates) of each transform to the
+ corresponding *input coordinates*.
+
+ Raises:
+ ValueError: If `transforms` have an invalid shape.
+ """
+ with ops.name_scope("flat_transforms_to_matrices"):
+ transforms = ops.convert_to_tensor(transforms, name="transforms")
+ if transforms.shape.ndims not in (1, 2):
+ raise ValueError("Transforms should be 1D or 2D, got: %s" % transforms)
+ # Make the transform(s) 2D in case the input is a single transform.
+ transforms = array_ops.reshape(transforms, constant_op.constant([-1, 8]))
+ num_transforms = array_ops.shape(transforms)[0]
+ # Add a column of ones for the implicit last entry in the matrix.
+ return array_ops.reshape(
+ array_ops.concat(
+ [transforms, array_ops.ones([num_transforms, 1])], axis=1),
+ constant_op.constant([-1, 3, 3]))
+
+
+def matrices_to_flat_transforms(transform_matrices):
+ """Converts affine matrices to `tf.contrib.image` projective transforms.
+
+ Note that we expect matrices that map output coordinates to input coordinates.
+ To convert forward transformation matrices, call `tf.linalg.inv` on the
+ matrices and use the result here.
+
+ Args:
+ transform_matrices: One or more affine transformation matrices, for the
+ reverse transformation in homogeneous coordinates. Shape `(3, 3)` or
+ `(N, 3, 3)`.
+
+ Returns:
+ 2D tensor of flat transforms with shape `(N, 8)`, which may be passed into
+ `tf.contrib.image.transform`.
+
+ Raises:
+ ValueError: If `transform_matrices` have an invalid shape.
+ """
+ with ops.name_scope("matrices_to_flat_transforms"):
+ transform_matrices = ops.convert_to_tensor(
+ transform_matrices, name="transform_matrices")
+ if transform_matrices.shape.ndims not in (2, 3):
+ raise ValueError(
+ "Matrices should be 2D or 3D, got: %s" % transform_matrices)
+ # Flatten each matrix.
+ transforms = array_ops.reshape(transform_matrices,
+ constant_op.constant([-1, 9]))
+ # Divide each matrix by the last entry (normally 1).
+ transforms /= transforms[:, 8:9]
+ return transforms[:, :8]
@ops.RegisterGradient("ImageProjectiveTransform")
@@ -346,9 +391,9 @@ def _image_projective_transform_grad(op, grad):
raise TypeError("Transforms should have rank 1 or 2.")
# Invert transformations
- transforms = _flat_transforms_to_matrices(transforms=transforms)
+ transforms = flat_transforms_to_matrices(transforms=transforms)
inverse = linalg_ops.matrix_inverse(transforms)
- transforms = _transform_matrices_to_flat(inverse)
+ transforms = matrices_to_flat_transforms(inverse)
output = gen_image_ops.image_projective_transform(
grad, transforms, interpolation=interpolation)
if len(image_or_images.get_shape()) == 2: