aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/image
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-04-19 13:21:25 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-19 13:24:36 -0700
commit7f1e64eb94447665047fac16c67b5351bcf3c8a3 (patch)
tree7f7e3ecadfee637eef39df58ff414089cc312716 /tensorflow/contrib/image
parent55706e693ab20f6200061fb73067cbf27707cccd (diff)
Allow output has a different shape from input in the image.transform (#17011).
PiperOrigin-RevId: 193564222
Diffstat (limited to 'tensorflow/contrib/image')
-rw-r--r--tensorflow/contrib/image/kernels/image_ops.cc7
-rw-r--r--tensorflow/contrib/image/kernels/image_ops.h2
-rw-r--r--tensorflow/contrib/image/ops/image_ops.cc52
-rw-r--r--tensorflow/contrib/image/python/kernel_tests/image_ops_test.py30
-rw-r--r--tensorflow/contrib/image/python/ops/image_ops.py39
5 files changed, 107 insertions, 23 deletions
diff --git a/tensorflow/contrib/image/kernels/image_ops.cc b/tensorflow/contrib/image/kernels/image_ops.cc
index c2e32da133..ae4b1ba62a 100644
--- a/tensorflow/contrib/image/kernels/image_ops.cc
+++ b/tensorflow/contrib/image/kernels/image_ops.cc
@@ -70,6 +70,7 @@ class ImageProjectiveTransform : public OpKernel {
void Compute(OpKernelContext* ctx) override {
const Tensor& images_t = ctx->input(0);
const Tensor& transform_t = ctx->input(1);
+ const Tensor& output_dim = ctx->input(2);
OP_REQUIRES(ctx, images_t.shape().dims() == 4,
errors::InvalidArgument("Input images must have rank 4"));
OP_REQUIRES(ctx,
@@ -83,7 +84,11 @@ class ImageProjectiveTransform : public OpKernel {
auto images = images_t.tensor<T, 4>();
auto transform = transform_t.matrix<float>();
Tensor* output_t;
- OP_REQUIRES_OK(ctx, ctx->allocate_output(0, images_t.shape(), &output_t));
+ // Image is NHWC format.
+ auto output_shape = images_t.shape();
+ output_shape.set_dim(1, output_dim.vec<int>()(0));
+ output_shape.set_dim(2, output_dim.vec<int>()(1));
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(0, output_shape, &output_t));
auto output = output_t->tensor<T, 4>();
(FillProjectiveTransform<Device, T>(interpolation_))(
ctx->eigen_device<Device>(), &output, images, transform);
diff --git a/tensorflow/contrib/image/kernels/image_ops.h b/tensorflow/contrib/image/kernels/image_ops.h
index ad50133061..2320329b92 100644
--- a/tensorflow/contrib/image/kernels/image_ops.h
+++ b/tensorflow/contrib/image/kernels/image_ops.h
@@ -161,7 +161,7 @@ struct FillProjectiveTransform {
void operator()(const Device& device, OutputType* output,
const InputType& images,
const TransformsType& transform) const {
- output->device(device) = images.generate(
+ output->device(device) = output->generate(
ProjectiveGenerator<Device, T>(images, transform, interpolation_));
}
};
diff --git a/tensorflow/contrib/image/ops/image_ops.cc b/tensorflow/contrib/image/ops/image_ops.cc
index 68771b3d05..4c6d8c0d19 100644
--- a/tensorflow/contrib/image/ops/image_ops.cc
+++ b/tensorflow/contrib/image/ops/image_ops.cc
@@ -19,9 +19,55 @@ limitations under the License.
namespace tensorflow {
+using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;
+namespace {
+
+// Sets output[0] to shape [batch_dim,height,width,channel_dim], where
+// height and width come from the size_tensor.
+Status SetOutputToSizedImage(InferenceContext* c, DimensionHandle batch_dim,
+ int size_input_idx, DimensionHandle channel_dim) {
+ // Verify shape of size input.
+ ShapeHandle size;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(size_input_idx), 1, &size));
+ DimensionHandle unused;
+ TF_RETURN_IF_ERROR(c->WithValue(c->Dim(size, 0), 2, &unused));
+
+ // Get size values from the size tensor.
+ const Tensor* size_tensor = c->input_tensor(size_input_idx);
+ DimensionHandle width;
+ DimensionHandle height;
+ if (size_tensor == nullptr) {
+ width = c->UnknownDim();
+ height = c->UnknownDim();
+ } else {
+ // TODO(petewarden) - Remove once we have constant evaluation in C++ only.
+ if (size_tensor->dtype() != DT_INT32) {
+ return errors::InvalidArgument(
+ "Bad size input type for SetOutputToSizedImage: Expected DT_INT32 "
+ "but got ",
+ DataTypeString(size_tensor->dtype()), " for input #", size_input_idx,
+ " in ", c->DebugString());
+ }
+ auto vec = size_tensor->vec<int32>();
+ height = c->MakeDim(vec(0));
+ width = c->MakeDim(vec(1));
+ }
+ c->set_output(0, c->MakeShape({batch_dim, height, width, channel_dim}));
+ return Status::OK();
+}
+
+Status ResizeShapeFn(InferenceContext* c) {
+ ShapeHandle input;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
+ return SetOutputToSizedImage(c, c->Dim(input, 0), 2 /* size_input_idx */,
+ c->Dim(input, 3));
+}
+
+} // namespace
+
// TODO(ringwalt): Add a "fill_mode" argument with "constant", "mirror", etc.
// TODO(ringwalt): Add a "fill_constant" argument for constant mode (default 0).
// TODO(ringwalt): Add an "output_shape" argument. This is sufficient to
@@ -29,13 +75,11 @@ using shape_inference::ShapeHandle;
REGISTER_OP("ImageProjectiveTransform")
.Input("images: dtype")
.Input("transforms: float32")
+ .Input("output_shape: int32")
.Attr("dtype: {uint8, int32, int64, float32, float64}")
.Attr("interpolation: string")
.Output("transformed_images: dtype")
- .SetShapeFn([](InferenceContext* c) {
- c->set_output(0, c->input(0));
- return Status::OK();
- })
+ .SetShapeFn(ResizeShapeFn)
.Doc(R"doc(
Applies the given transform to each of the images.
diff --git a/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py b/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py
index b50177ae56..c0151d320f 100644
--- a/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py
+++ b/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py
@@ -195,10 +195,40 @@ class ImageOpsTest(test_util.TensorFlowTestCase):
x_init_value=test_image)
self.assertLess(left_err, 1e-10)
+ def _test_grad_different_shape(self, input_shape, output_shape):
+ with self.test_session():
+ test_image_shape = input_shape
+ test_image = np.random.randn(*test_image_shape)
+ test_image_tensor = constant_op.constant(
+ test_image, shape=test_image_shape)
+ test_transform = image_ops.angles_to_projective_transforms(
+ np.pi / 2, 4, 4)
+
+ if len(output_shape) == 2:
+ resize_shape = output_shape
+ elif len(output_shape) == 3:
+ resize_shape = output_shape[0:2]
+ elif len(output_shape) == 4:
+ resize_shape = output_shape[1:3]
+ output = image_ops.transform(
+ images=test_image_tensor,
+ transforms=test_transform,
+ output_shape=resize_shape)
+ left_err = gradient_checker.compute_gradient_error(
+ test_image_tensor,
+ test_image_shape,
+ output,
+ output_shape,
+ x_init_value=test_image)
+ self.assertLess(left_err, 1e-10)
+
def test_grad(self):
self._test_grad([16, 16])
self._test_grad([4, 12, 12])
self._test_grad([3, 4, 12, 12])
+ self._test_grad_different_shape([16, 16], [8, 8])
+ self._test_grad_different_shape([4, 12, 3], [8, 24, 3])
+ self._test_grad_different_shape([3, 4, 12, 3], [3, 8, 24, 3])
class BipartiteMatchTest(test_util.TensorFlowTestCase):
diff --git a/tensorflow/contrib/image/python/ops/image_ops.py b/tensorflow/contrib/image/python/ops/image_ops.py
index c139ae89d8..0cb7bdc75d 100644
--- a/tensorflow/contrib/image/python/ops/image_ops.py
+++ b/tensorflow/contrib/image/python/ops/image_ops.py
@@ -212,7 +212,11 @@ def translations_to_projective_transforms(translations, name=None):
axis=1)
-def transform(images, transforms, interpolation="NEAREST", name=None):
+def transform(images,
+ transforms,
+ output_shape=None,
+ interpolation="NEAREST",
+ name=None):
"""Applies the given transform(s) to the image(s).
Args:
@@ -228,7 +232,10 @@ def transform(images, transforms, interpolation="NEAREST", name=None):
where `k = c0 x + c1 y + 1`. The transforms are *inverted* compared to
the transform mapping input points to output points. Note that gradients
are not backpropagated into transformation parameters.
+ output_shape: Output dimesion after the transform, [height, width].
+ If None, output is the same size as input image.
interpolation: Interpolation mode. Supported values: "NEAREST", "BILINEAR".
+ name: The name of the op.
Returns:
Image(s) with the same type and shape as `images`, with the given
@@ -255,6 +262,14 @@ def transform(images, transforms, interpolation="NEAREST", name=None):
else:
raise TypeError("Images should have rank between 2 and 4.")
+ if output_shape is None:
+ output_shape = images.get_shape()[1:3]
+ elif len(output_shape) != 2:
+ raise TypeError(
+ "output_shape must either be None or a vector of 2 elements.")
+ output_shape = ops.convert_to_tensor(
+ output_shape, name="output_shape", dtype=dtypes.int32)
+
if len(transform_or_transforms.get_shape()) == 1:
transforms = transform_or_transforms[None]
elif transform_or_transforms.get_shape().ndims is None:
@@ -265,7 +280,7 @@ def transform(images, transforms, interpolation="NEAREST", name=None):
else:
raise TypeError("Transforms should have rank 1 or 2.")
output = gen_image_ops.image_projective_transform(
- images, transforms, interpolation=interpolation.upper())
+ images, transforms, output_shape, interpolation=interpolation.upper())
if len(image_or_images.get_shape()) == 2:
return output[0, :, :, 0]
elif len(image_or_images.get_shape()) == 3:
@@ -375,14 +390,6 @@ def _image_projective_transform_grad(op, grad):
if image_or_images.dtype.base_dtype not in _IMAGE_DTYPES:
raise TypeError("Invalid dtype %s." % image_or_images.dtype)
- if len(image_or_images.get_shape()) == 2:
- images = image_or_images[None, :, :, None]
- elif len(image_or_images.get_shape()) == 3:
- images = image_or_images[None, :, :, :]
- elif len(image_or_images.get_shape()) == 4:
- images = image_or_images
- else:
- raise TypeError("Images should have rank between 2 and 4")
if len(transform_or_transforms.get_shape()) == 1:
transforms = transform_or_transforms[None]
elif len(transform_or_transforms.get_shape()) == 2:
@@ -395,13 +402,11 @@ def _image_projective_transform_grad(op, grad):
inverse = linalg_ops.matrix_inverse(transforms)
transforms = matrices_to_flat_transforms(inverse)
output = gen_image_ops.image_projective_transform(
- grad, transforms, interpolation=interpolation)
- if len(image_or_images.get_shape()) == 2:
- return [output[0, :, :, 0], None]
- elif len(image_or_images.get_shape()) == 3:
- return [output[0, :, :, :], None]
- else:
- return [output, None]
+ images=grad,
+ transforms=transforms,
+ output_shape=image_or_images.get_shape()[1:3],
+ interpolation=interpolation)
+ return [output, None, None]
def bipartite_match(distance_mat,