aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/image
diff options
context:
space:
mode:
authorGravatar Dan Ringwalt <ringwalt@google.com>2018-08-10 12:12:03 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-10 12:17:07 -0700
commitf7bf204cb361aaf238767c701ebb010b7f0ba986 (patch)
treebfbe156efea6bbabd47916a095774451dc8654e5 /tensorflow/contrib/image
parentd2bec062040bfa083195c06deef56981c6f0f946 (diff)
BEGIN_PUBLIC
Allow a different output shape from the input in tf.contrib.image.transform (#17011). END_PUBLIC RELNOTES: Allow a different output shape from the input in tf.contrib.image.transform. Thanks qyu@ for making the original change and fixing a few other prior issues! Automated rollback of commit 07fdb697d33478d7a72d09fc2371fa834e870b83 PiperOrigin-RevId: 208248183
Diffstat (limited to 'tensorflow/contrib/image')
-rw-r--r--tensorflow/contrib/image/kernels/image_ops.cc33
-rw-r--r--tensorflow/contrib/image/kernels/image_ops.h2
-rw-r--r--tensorflow/contrib/image/ops/image_ops.cc57
-rw-r--r--tensorflow/contrib/image/python/kernel_tests/image_ops_test.py44
-rw-r--r--tensorflow/contrib/image/python/ops/image_ops.py52
5 files changed, 156 insertions, 32 deletions
diff --git a/tensorflow/contrib/image/kernels/image_ops.cc b/tensorflow/contrib/image/kernels/image_ops.cc
index 022e17d139..693724b457 100644
--- a/tensorflow/contrib/image/kernels/image_ops.cc
+++ b/tensorflow/contrib/image/kernels/image_ops.cc
@@ -71,6 +71,7 @@ class ImageProjectiveTransform : public OpKernel {
void Compute(OpKernelContext* ctx) override {
const Tensor& images_t = ctx->input(0);
const Tensor& transform_t = ctx->input(1);
+ const Tensor& shape_t = ctx->input(2);
OP_REQUIRES(ctx, images_t.shape().dims() == 4,
errors::InvalidArgument("Input images must have rank 4"));
OP_REQUIRES(ctx,
@@ -81,11 +82,28 @@ class ImageProjectiveTransform : public OpKernel {
ProjectiveGenerator<Device, T>::kNumParameters),
errors::InvalidArgument(
"Input transform should be num_images x 8 or 1 x 8"));
- auto images = images_t.tensor<T, 4>();
- auto transform = transform_t.matrix<float>();
+ OP_REQUIRES(ctx, shape_t.dims() == 1,
+ errors::InvalidArgument("output shape must be 1-dimensional",
+ shape_t.shape().DebugString()));
+ OP_REQUIRES(ctx, shape_t.NumElements() == 2,
+ errors::InvalidArgument("output shape must have two elements",
+ shape_t.shape().DebugString()));
+ auto shape_vec = shape_t.vec<int32>();
+ int32 out_height = shape_vec(0);
+ int32 out_width = shape_vec(1);
+ OP_REQUIRES(ctx, out_height > 0 && out_width > 0,
+ errors::InvalidArgument("output dimensions must be positive"));
+
Tensor* output_t;
- OP_REQUIRES_OK(ctx, ctx->allocate_output(0, images_t.shape(), &output_t));
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(
+ 0,
+ TensorShape({images_t.dim_size(0), out_height,
+ out_width, images_t.dim_size(3)}),
+ &output_t));
auto output = output_t->tensor<T, 4>();
+ auto images = images_t.tensor<T, 4>();
+ auto transform = transform_t.matrix<float>();
+
(FillProjectiveTransform<Device, T>(interpolation_))(
ctx->eigen_device<Device>(), &output, images, transform);
}
@@ -129,10 +147,11 @@ TF_CALL_double(DECLARE_FUNCTOR);
} // end namespace functor
-#define REGISTER(TYPE) \
- REGISTER_KERNEL_BUILDER(Name("ImageProjectiveTransform") \
- .Device(DEVICE_GPU) \
- .TypeConstraint<TYPE>("dtype"), \
+#define REGISTER(TYPE) \
+ REGISTER_KERNEL_BUILDER(Name("ImageProjectiveTransform") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<TYPE>("dtype") \
+ .HostMemory("output_shape"), \
ImageProjectiveTransform<GPUDevice, TYPE>)
TF_CALL_uint8(REGISTER);
diff --git a/tensorflow/contrib/image/kernels/image_ops.h b/tensorflow/contrib/image/kernels/image_ops.h
index 209aa24548..6b63eed130 100644
--- a/tensorflow/contrib/image/kernels/image_ops.h
+++ b/tensorflow/contrib/image/kernels/image_ops.h
@@ -167,7 +167,7 @@ struct FillProjectiveTransform {
void operator()(const Device& device, OutputType* output,
const InputType& images,
const TransformsType& transform) const {
- output->device(device) = images.generate(
+ output->device(device) = output->generate(
ProjectiveGenerator<Device, T>(images, transform, interpolation_));
}
};
diff --git a/tensorflow/contrib/image/ops/image_ops.cc b/tensorflow/contrib/image/ops/image_ops.cc
index e59f1bf844..4969ac58f9 100644
--- a/tensorflow/contrib/image/ops/image_ops.cc
+++ b/tensorflow/contrib/image/ops/image_ops.cc
@@ -19,23 +19,66 @@ limitations under the License.
namespace tensorflow {
+using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;
+namespace {
+
+// Sets output[0] to shape [batch_dim,height,width,channel_dim], where
+// height and width come from the size_tensor.
+Status SetOutputToSizedImage(InferenceContext* c, DimensionHandle batch_dim,
+ int size_input_idx, DimensionHandle channel_dim) {
+ // Verify shape of size input.
+ ShapeHandle size;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(size_input_idx), 1, &size));
+ DimensionHandle unused;
+ TF_RETURN_IF_ERROR(c->WithValue(c->Dim(size, 0), 2, &unused));
+
+ // Get size values from the size tensor.
+ const Tensor* size_tensor = c->input_tensor(size_input_idx);
+ DimensionHandle width;
+ DimensionHandle height;
+ if (size_tensor == nullptr) {
+ width = c->UnknownDim();
+ height = c->UnknownDim();
+ } else {
+ // TODO(petewarden) - Remove once we have constant evaluation in C++ only.
+ if (size_tensor->dtype() != DT_INT32) {
+ return errors::InvalidArgument(
+ "Bad size input type for SetOutputToSizedImage: Expected DT_INT32 "
+ "but got ",
+ DataTypeString(size_tensor->dtype()), " for input #", size_input_idx,
+ " in ", c->DebugString());
+ }
+ auto vec = size_tensor->vec<int32>();
+ height = c->MakeDim(vec(0));
+ width = c->MakeDim(vec(1));
+ }
+ c->set_output(0, c->MakeShape({batch_dim, height, width, channel_dim}));
+ return Status::OK();
+}
+
+// TODO(qyu): Move this to core/framework/common_shape_fns.h
+Status ResizeShapeFn(InferenceContext* c) {
+ ShapeHandle input;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
+ return SetOutputToSizedImage(c, c->Dim(input, 0), 2 /* size_input_idx */,
+ c->Dim(input, 3));
+}
+
+} // namespace
+
// TODO(ringwalt): Add a "fill_mode" argument with "constant", "mirror", etc.
// TODO(ringwalt): Add a "fill_constant" argument for constant mode (default 0).
-// TODO(ringwalt): Add an "output_shape" argument. This is sufficient to
-// implement "same" and "valid" modes in the Python function.
REGISTER_OP("ImageProjectiveTransform")
.Input("images: dtype")
.Input("transforms: float32")
+ .Input("output_shape: int32")
.Attr("dtype: {uint8, int32, int64, float16, float32, float64}")
.Attr("interpolation: string")
.Output("transformed_images: dtype")
- .SetShapeFn([](InferenceContext* c) {
- c->set_output(0, c->input(0));
- return Status::OK();
- })
+ .SetShapeFn(ResizeShapeFn)
.Doc(R"doc(
Applies the given transform to each of the images.
@@ -49,7 +92,7 @@ If one row of `transforms` is `[a0, a1, a2, b0, b1, b2, c0, c1]`, then it maps
the *output* point `(x, y)` to a transformed *input* point
`(x', y') = ((a0 x + a1 y + a2) / k, (b0 x + b1 y + b2) / k)`, where
`k = c0 x + c1 y + 1`. If the transformed point lays outside of the input
-image, the output pixel is set to 0. The output is the same size as the input,
+image, the output pixel is set to 0.
images: 4D `Tensor`, input image(s) in NHWC format.
transforms: 2D `Tensor`, projective transform(s) to apply to the image(s).
diff --git a/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py b/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py
index 62a22dcf34..f588eae923 100644
--- a/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py
+++ b/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py
@@ -27,6 +27,7 @@ from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
from tensorflow.python.platform import googletest
_DTYPES = set(
@@ -194,6 +195,19 @@ class ImageOpsTest(test_util.TensorFlowTestCase):
[0.0, 149, 233, 149, 0.0],
[0.0, 0.0, 87., 0.0, 0.0]])
+ def test_rotate_static_shape(self):
+ image = array_ops.diag([1., 2., 3.])
+ result = image_ops.rotate(
+ image, random_ops.random_uniform((), -1, 1), interpolation="BILINEAR")
+ self.assertEqual(image.get_shape(), result.get_shape())
+
+ def test_transform_static_output_shape(self):
+ image = constant_op.constant([[1., 2.], [3., 4.]])
+ result = image_ops.transform(
+ image, random_ops.random_uniform([8], -1, 1),
+ output_shape=constant_op.constant([3, 5]))
+ self.assertAllEqual([3, 5], result.get_shape())
+
def _test_grad(self, shape_to_test):
with self.test_session():
test_image_shape = shape_to_test
@@ -213,10 +227,40 @@ class ImageOpsTest(test_util.TensorFlowTestCase):
x_init_value=test_image)
self.assertLess(left_err, 1e-10)
+ def _test_grad_different_shape(self, input_shape, output_shape):
+ with self.test_session():
+ test_image_shape = input_shape
+ test_image = np.random.randn(*test_image_shape)
+ test_image_tensor = constant_op.constant(
+ test_image, shape=test_image_shape)
+ test_transform = image_ops.angles_to_projective_transforms(
+ np.pi / 2, 4, 4)
+
+ if len(output_shape) == 2:
+ resize_shape = output_shape
+ elif len(output_shape) == 3:
+ resize_shape = output_shape[0:2]
+ elif len(output_shape) == 4:
+ resize_shape = output_shape[1:3]
+ output = image_ops.transform(
+ images=test_image_tensor,
+ transforms=test_transform,
+ output_shape=resize_shape)
+ left_err = gradient_checker.compute_gradient_error(
+ test_image_tensor,
+ test_image_shape,
+ output,
+ output_shape,
+ x_init_value=test_image)
+ self.assertLess(left_err, 1e-10)
+
def test_grad(self):
self._test_grad([16, 16])
self._test_grad([4, 12, 12])
self._test_grad([3, 4, 12, 12])
+ self._test_grad_different_shape([16, 16], [8, 8])
+ self._test_grad_different_shape([4, 12, 3], [8, 24, 3])
+ self._test_grad_different_shape([3, 4, 12, 3], [3, 8, 24, 3])
class BipartiteMatchTest(test_util.TensorFlowTestCase):
diff --git a/tensorflow/contrib/image/python/ops/image_ops.py b/tensorflow/contrib/image/python/ops/image_ops.py
index 86b0ffe9a0..e7a09041ad 100644
--- a/tensorflow/contrib/image/python/ops/image_ops.py
+++ b/tensorflow/contrib/image/python/ops/image_ops.py
@@ -23,6 +23,7 @@ from tensorflow.python.framework import common_shapes
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import linalg_ops
@@ -40,6 +41,9 @@ ops.RegisterShape("ImageConnectedComponents")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("ImageProjectiveTransform")(common_shapes.call_cpp_shape_fn)
+# TODO(ringwalt): Support a "reshape" (name used by SciPy) or "expand" (name
+# used by PIL, maybe more readable) mode, which determines the correct
+# output_shape and translation for the transform.
def rotate(images, angles, interpolation="NEAREST", name=None):
"""Rotate image(s) counterclockwise by the passed angle(s) in radians.
@@ -213,7 +217,11 @@ def translations_to_projective_transforms(translations, name=None):
axis=1)
-def transform(images, transforms, interpolation="NEAREST", name=None):
+def transform(images,
+ transforms,
+ interpolation="NEAREST",
+ output_shape=None,
+ name=None):
"""Applies the given transform(s) to the image(s).
Args:
@@ -230,6 +238,10 @@ def transform(images, transforms, interpolation="NEAREST", name=None):
the transform mapping input points to output points. Note that gradients
are not backpropagated into transformation parameters.
interpolation: Interpolation mode. Supported values: "NEAREST", "BILINEAR".
+ output_shape: Output dimesion after the transform, [height, width].
+ If None, output is the same size as input image.
+
+ name: The name of the op.
Returns:
Image(s) with the same type and shape as `images`, with the given
@@ -238,6 +250,7 @@ def transform(images, transforms, interpolation="NEAREST", name=None):
Raises:
TypeError: If `image` is an invalid type.
+ ValueError: If output shape is not 1-D int32 Tensor.
"""
with ops.name_scope(name, "transform"):
image_or_images = ops.convert_to_tensor(images, name="images")
@@ -256,6 +269,17 @@ def transform(images, transforms, interpolation="NEAREST", name=None):
else:
raise TypeError("Images should have rank between 2 and 4.")
+ if output_shape is None:
+ output_shape = tensor_util.constant_value(
+ array_ops.shape(images)[1:3]) or array_ops.shape(images)[1:3]
+
+ output_shape = ops.convert_to_tensor(
+ output_shape, dtypes.int32, name="output_shape")
+
+ if not output_shape.get_shape().is_compatible_with([2]):
+ raise ValueError("output_shape must be a 1-D Tensor of 2 elements: "
+ "new_height, new_width")
+
if len(transform_or_transforms.get_shape()) == 1:
transforms = transform_or_transforms[None]
elif transform_or_transforms.get_shape().ndims is None:
@@ -265,8 +289,12 @@ def transform(images, transforms, interpolation="NEAREST", name=None):
transforms = transform_or_transforms
else:
raise TypeError("Transforms should have rank 1 or 2.")
+
output = gen_image_ops.image_projective_transform(
- images, transforms, interpolation=interpolation.upper())
+ images,
+ output_shape=output_shape,
+ transforms=transforms,
+ interpolation=interpolation.upper())
if len(image_or_images.get_shape()) == 2:
return output[0, :, :, 0]
elif len(image_or_images.get_shape()) == 3:
@@ -376,14 +404,6 @@ def _image_projective_transform_grad(op, grad):
if image_or_images.dtype.base_dtype not in _IMAGE_DTYPES:
raise TypeError("Invalid dtype %s." % image_or_images.dtype)
- if len(image_or_images.get_shape()) == 2:
- images = image_or_images[None, :, :, None]
- elif len(image_or_images.get_shape()) == 3:
- images = image_or_images[None, :, :, :]
- elif len(image_or_images.get_shape()) == 4:
- images = image_or_images
- else:
- raise TypeError("Images should have rank between 2 and 4")
if len(transform_or_transforms.get_shape()) == 1:
transforms = transform_or_transforms[None]
elif len(transform_or_transforms.get_shape()) == 2:
@@ -396,13 +416,11 @@ def _image_projective_transform_grad(op, grad):
inverse = linalg_ops.matrix_inverse(transforms)
transforms = matrices_to_flat_transforms(inverse)
output = gen_image_ops.image_projective_transform(
- grad, transforms, interpolation=interpolation)
- if len(image_or_images.get_shape()) == 2:
- return [output[0, :, :, 0], None]
- elif len(image_or_images.get_shape()) == 3:
- return [output[0, :, :, :], None]
- else:
- return [output, None]
+ images=grad,
+ transforms=transforms,
+ output_shape=array_ops.shape(image_or_images)[1:3],
+ interpolation=interpolation)
+ return [output, None, None]
def bipartite_match(distance_mat,