From 9e5fdb83e609701457f6fdc2d153b1f7e83ead6c Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Thu, 19 Apr 2018 15:56:17 -0700
Subject: Automated g4 rollback of changelist 193564222

PiperOrigin-RevId: 193588935
---
 tensorflow/contrib/image/kernels/image_ops.cc    |  7 +--
 tensorflow/contrib/image/kernels/image_ops.h     |  2 +-
 tensorflow/contrib/image/ops/image_ops.cc        | 52 ++--------------------
 .../image/python/kernel_tests/image_ops_test.py  | 30 -------------
 tensorflow/contrib/image/python/ops/image_ops.py | 39 +++++++---------
 5 files changed, 23 insertions(+), 107 deletions(-)

diff --git a/tensorflow/contrib/image/kernels/image_ops.cc b/tensorflow/contrib/image/kernels/image_ops.cc
index ae4b1ba62a..c2e32da133 100644
--- a/tensorflow/contrib/image/kernels/image_ops.cc
+++ b/tensorflow/contrib/image/kernels/image_ops.cc
@@ -70,7 +70,6 @@ class ImageProjectiveTransform : public OpKernel {
   void Compute(OpKernelContext* ctx) override {
     const Tensor& images_t = ctx->input(0);
     const Tensor& transform_t = ctx->input(1);
-    const Tensor& output_dim = ctx->input(2);
     OP_REQUIRES(ctx, images_t.shape().dims() == 4,
                 errors::InvalidArgument("Input images must have rank 4"));
     OP_REQUIRES(ctx,
@@ -84,11 +83,7 @@ class ImageProjectiveTransform : public OpKernel {
     auto images = images_t.tensor<T, 4>();
     auto transform = transform_t.matrix<float>();
     Tensor* output_t;
-    // Image is NHWC format.
-    auto output_shape = images_t.shape();
-    output_shape.set_dim(1, output_dim.vec<int32>()(0));
-    output_shape.set_dim(2, output_dim.vec<int32>()(1));
-    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, output_shape, &output_t));
+    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, images_t.shape(), &output_t));
     auto output = output_t->tensor<T, 4>();
     (FillProjectiveTransform<Device, T>(interpolation_))(
         ctx->eigen_device<Device>(), &output, images, transform);
diff --git a/tensorflow/contrib/image/kernels/image_ops.h b/tensorflow/contrib/image/kernels/image_ops.h
index 2320329b92..ad50133061 100644
--- a/tensorflow/contrib/image/kernels/image_ops.h
+++ b/tensorflow/contrib/image/kernels/image_ops.h
@@ -161,7 +161,7 @@ struct FillProjectiveTransform {
   void operator()(const Device& device, OutputType* output,
                   const InputType& images,
                   const TransformsType& transform) const {
-    output->device(device) = output->generate(
+    output->device(device) = images.generate(
         ProjectiveGenerator<Device, T>(images, transform, interpolation_));
   }
 };
diff --git a/tensorflow/contrib/image/ops/image_ops.cc b/tensorflow/contrib/image/ops/image_ops.cc
index 4c6d8c0d19..68771b3d05 100644
--- a/tensorflow/contrib/image/ops/image_ops.cc
+++ b/tensorflow/contrib/image/ops/image_ops.cc
@@ -19,55 +19,9 @@ limitations under the License.
 
 namespace tensorflow {
 
-using shape_inference::DimensionHandle;
 using shape_inference::InferenceContext;
 using shape_inference::ShapeHandle;
 
-namespace {
-
-// Sets output[0] to shape [batch_dim,height,width,channel_dim], where
-// height and width come from the size_tensor.
-Status SetOutputToSizedImage(InferenceContext* c, DimensionHandle batch_dim,
-                             int size_input_idx, DimensionHandle channel_dim) {
-  // Verify shape of size input.
-  ShapeHandle size;
-  TF_RETURN_IF_ERROR(c->WithRank(c->input(size_input_idx), 1, &size));
-  DimensionHandle unused;
-  TF_RETURN_IF_ERROR(c->WithValue(c->Dim(size, 0), 2, &unused));
-
-  // Get size values from the size tensor.
-  const Tensor* size_tensor = c->input_tensor(size_input_idx);
-  DimensionHandle width;
-  DimensionHandle height;
-  if (size_tensor == nullptr) {
-    width = c->UnknownDim();
-    height = c->UnknownDim();
-  } else {
-    // TODO(petewarden) - Remove once we have constant evaluation in C++ only.
-    if (size_tensor->dtype() != DT_INT32) {
-      return errors::InvalidArgument(
-          "Bad size input type for SetOutputToSizedImage: Expected DT_INT32 "
-          "but got ",
-          DataTypeString(size_tensor->dtype()), " for input #", size_input_idx,
-          " in ", c->DebugString());
-    }
-    auto vec = size_tensor->vec<int32>();
-    height = c->MakeDim(vec(0));
-    width = c->MakeDim(vec(1));
-  }
-  c->set_output(0, c->MakeShape({batch_dim, height, width, channel_dim}));
-  return Status::OK();
-}
-
-Status ResizeShapeFn(InferenceContext* c) {
-  ShapeHandle input;
-  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
-  return SetOutputToSizedImage(c, c->Dim(input, 0), 2 /* size_input_idx */,
-                               c->Dim(input, 3));
-}
-
-}  // namespace
-
 // TODO(ringwalt): Add a "fill_mode" argument with "constant", "mirror", etc.
 // TODO(ringwalt): Add a "fill_constant" argument for constant mode (default 0).
 // TODO(ringwalt): Add an "output_shape" argument. This is sufficient to
@@ -75,11 +29,13 @@ Status ResizeShapeFn(InferenceContext* c) {
 REGISTER_OP("ImageProjectiveTransform")
     .Input("images: dtype")
     .Input("transforms: float32")
-    .Input("output_shape: int32")
     .Attr("dtype: {uint8, int32, int64, float32, float64}")
     .Attr("interpolation: string")
     .Output("transformed_images: dtype")
-    .SetShapeFn(ResizeShapeFn)
+    .SetShapeFn([](InferenceContext* c) {
+      c->set_output(0, c->input(0));
+      return Status::OK();
+    })
     .Doc(R"doc(
 Applies the given transform to each of the images.
 
diff --git a/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py b/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py
index c0151d320f..b50177ae56 100644
--- a/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py
+++ b/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py
@@ -195,40 +195,10 @@ class ImageOpsTest(test_util.TensorFlowTestCase):
           x_init_value=test_image)
       self.assertLess(left_err, 1e-10)
 
-  def _test_grad_different_shape(self, input_shape, output_shape):
-    with self.test_session():
-      test_image_shape = input_shape
-      test_image = np.random.randn(*test_image_shape)
-      test_image_tensor = constant_op.constant(
-          test_image, shape=test_image_shape)
-      test_transform = image_ops.angles_to_projective_transforms(
-          np.pi / 2, 4, 4)
-
-      if len(output_shape) == 2:
-        resize_shape = output_shape
-      elif len(output_shape) == 3:
-        resize_shape = output_shape[0:2]
-      elif len(output_shape) == 4:
-        resize_shape = output_shape[1:3]
-      output = image_ops.transform(
-          images=test_image_tensor,
-          transforms=test_transform,
-          output_shape=resize_shape)
-      left_err = gradient_checker.compute_gradient_error(
-          test_image_tensor,
-          test_image_shape,
-          output,
-          output_shape,
-          x_init_value=test_image)
-      self.assertLess(left_err, 1e-10)
-
   def test_grad(self):
     self._test_grad([16, 16])
     self._test_grad([4, 12, 12])
     self._test_grad([3, 4, 12, 12])
-    self._test_grad_different_shape([16, 16], [8, 8])
-    self._test_grad_different_shape([4, 12, 3], [8, 24, 3])
-    self._test_grad_different_shape([3, 4, 12, 3], [3, 8, 24, 3])
 
 
 class BipartiteMatchTest(test_util.TensorFlowTestCase):
diff --git a/tensorflow/contrib/image/python/ops/image_ops.py b/tensorflow/contrib/image/python/ops/image_ops.py
index 0cb7bdc75d..c139ae89d8 100644
--- a/tensorflow/contrib/image/python/ops/image_ops.py
+++ b/tensorflow/contrib/image/python/ops/image_ops.py
@@ -212,11 +212,7 @@ def translations_to_projective_transforms(translations, name=None):
       axis=1)
 
 
-def transform(images,
-              transforms,
-              output_shape=None,
-              interpolation="NEAREST",
-              name=None):
+def transform(images, transforms, interpolation="NEAREST", name=None):
   """Applies the given transform(s) to the image(s).
 
   Args:
@@ -232,10 +228,7 @@ def transform(images,
        where `k = c0 x + c1 y + 1`. The transforms are *inverted* compared to
       the transform mapping input points to output points. Note that
       gradients are not backpropagated into transformation parameters.
-    output_shape: Output dimesion after the transform, [height, width].
-      If None, output is the same size as input image.
     interpolation: Interpolation mode. Supported values: "NEAREST", "BILINEAR".
-    name: The name of the op.
 
   Returns:
     Image(s) with the same type and shape as `images`, with the given
@@ -262,14 +255,6 @@ def transform(images,
     else:
       raise TypeError("Images should have rank between 2 and 4.")
 
-    if output_shape is None:
-      output_shape = images.get_shape()[1:3]
-    elif len(output_shape) != 2:
-      raise TypeError(
-          "output_shape must either be None or a vector of 2 elements.")
-    output_shape = ops.convert_to_tensor(
-        output_shape, name="output_shape", dtype=dtypes.int32)
-
     if len(transform_or_transforms.get_shape()) == 1:
       transforms = transform_or_transforms[None]
     elif transform_or_transforms.get_shape().ndims is None:
@@ -280,7 +265,7 @@ def transform(images,
     else:
       raise TypeError("Transforms should have rank 1 or 2.")
     output = gen_image_ops.image_projective_transform(
-        images, transforms, output_shape, interpolation=interpolation.upper())
+        images, transforms, interpolation=interpolation.upper())
     if len(image_or_images.get_shape()) == 2:
       return output[0, :, :, 0]
     elif len(image_or_images.get_shape()) == 3:
@@ -390,6 +375,14 @@ def _image_projective_transform_grad(op, grad):
 
   if image_or_images.dtype.base_dtype not in _IMAGE_DTYPES:
     raise TypeError("Invalid dtype %s." % image_or_images.dtype)
+  if len(image_or_images.get_shape()) == 2:
+    images = image_or_images[None, :, :, None]
+  elif len(image_or_images.get_shape()) == 3:
+    images = image_or_images[None, :, :, :]
+  elif len(image_or_images.get_shape()) == 4:
+    images = image_or_images
+  else:
+    raise TypeError("Images should have rank between 2 and 4")
   if len(transform_or_transforms.get_shape()) == 1:
     transforms = transform_or_transforms[None]
   elif len(transform_or_transforms.get_shape()) == 2:
@@ -402,11 +395,13 @@ def _image_projective_transform_grad(op, grad):
   inverse = linalg_ops.matrix_inverse(transforms)
   transforms = matrices_to_flat_transforms(inverse)
   output = gen_image_ops.image_projective_transform(
-      images=grad,
-      transforms=transforms,
-      output_shape=image_or_images.get_shape()[1:3],
-      interpolation=interpolation)
-  return [output, None, None]
+      grad, transforms, interpolation=interpolation)
+  if len(image_or_images.get_shape()) == 2:
+    return [output[0, :, :, 0], None]
+  elif len(image_or_images.get_shape()) == 3:
+    return [output[0, :, :, :], None]
+  else:
+    return [output, None]
 
 
 def bipartite_match(distance_mat,
-- 
cgit v1.2.3
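
For reference, a minimal usage sketch of the Python API as it stands after this rollback, assuming a TensorFlow 1.x build where tf.contrib.image is available (the session setup below is illustrative and not part of the patch). With the output_shape argument gone, transform() always returns images with the same shape as its input:

    # Sketch: calling tf.contrib.image.transform after this rollback (TF 1.x).
    # The output shape always matches the input shape.
    import numpy as np
    import tensorflow as tf

    images = tf.constant(np.random.rand(2, 16, 16, 3), dtype=tf.float32)
    # One 8-parameter projective transform per image (45-degree rotation here).
    transforms = tf.contrib.image.angles_to_projective_transforms(
        angles=[np.pi / 4, np.pi / 4], image_height=16, image_width=16)
    rotated = tf.contrib.image.transform(
        images, transforms, interpolation="BILINEAR")

    with tf.Session() as sess:
        out = sess.run(rotated)
        print(out.shape)  # (2, 16, 16, 3), same as the input batch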