aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/image
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-04-19 15:56:17 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-19 16:01:23 -0700
commit9e5fdb83e609701457f6fdc2d153b1f7e83ead6c (patch)
tree02e87b4c3cc780d663d18b15b12a6d1be6b38e98 /tensorflow/contrib/image
parentd5c32f4ccc85ad0d13f3a1f83e063211504cf976 (diff)
Automated g4 rollback of changelist 193564222
PiperOrigin-RevId: 193588935
Diffstat (limited to 'tensorflow/contrib/image')
-rw-r--r--tensorflow/contrib/image/kernels/image_ops.cc7
-rw-r--r--tensorflow/contrib/image/kernels/image_ops.h2
-rw-r--r--tensorflow/contrib/image/ops/image_ops.cc52
-rw-r--r--tensorflow/contrib/image/python/kernel_tests/image_ops_test.py30
-rw-r--r--tensorflow/contrib/image/python/ops/image_ops.py39
5 files changed, 23 insertions, 107 deletions
diff --git a/tensorflow/contrib/image/kernels/image_ops.cc b/tensorflow/contrib/image/kernels/image_ops.cc
index ae4b1ba62a..c2e32da133 100644
--- a/tensorflow/contrib/image/kernels/image_ops.cc
+++ b/tensorflow/contrib/image/kernels/image_ops.cc
@@ -70,7 +70,6 @@ class ImageProjectiveTransform : public OpKernel {
void Compute(OpKernelContext* ctx) override {
const Tensor& images_t = ctx->input(0);
const Tensor& transform_t = ctx->input(1);
- const Tensor& output_dim = ctx->input(2);
OP_REQUIRES(ctx, images_t.shape().dims() == 4,
errors::InvalidArgument("Input images must have rank 4"));
OP_REQUIRES(ctx,
@@ -84,11 +83,7 @@ class ImageProjectiveTransform : public OpKernel {
auto images = images_t.tensor<T, 4>();
auto transform = transform_t.matrix<float>();
Tensor* output_t;
- // Image is NHWC format.
- auto output_shape = images_t.shape();
- output_shape.set_dim(1, output_dim.vec<int>()(0));
- output_shape.set_dim(2, output_dim.vec<int>()(1));
- OP_REQUIRES_OK(ctx, ctx->allocate_output(0, output_shape, &output_t));
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(0, images_t.shape(), &output_t));
auto output = output_t->tensor<T, 4>();
(FillProjectiveTransform<Device, T>(interpolation_))(
ctx->eigen_device<Device>(), &output, images, transform);
diff --git a/tensorflow/contrib/image/kernels/image_ops.h b/tensorflow/contrib/image/kernels/image_ops.h
index 2320329b92..ad50133061 100644
--- a/tensorflow/contrib/image/kernels/image_ops.h
+++ b/tensorflow/contrib/image/kernels/image_ops.h
@@ -161,7 +161,7 @@ struct FillProjectiveTransform {
void operator()(const Device& device, OutputType* output,
const InputType& images,
const TransformsType& transform) const {
- output->device(device) = output->generate(
+ output->device(device) = images.generate(
ProjectiveGenerator<Device, T>(images, transform, interpolation_));
}
};
diff --git a/tensorflow/contrib/image/ops/image_ops.cc b/tensorflow/contrib/image/ops/image_ops.cc
index 4c6d8c0d19..68771b3d05 100644
--- a/tensorflow/contrib/image/ops/image_ops.cc
+++ b/tensorflow/contrib/image/ops/image_ops.cc
@@ -19,55 +19,9 @@ limitations under the License.
namespace tensorflow {
-using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;
-namespace {
-
-// Sets output[0] to shape [batch_dim,height,width,channel_dim], where
-// height and width come from the size_tensor.
-Status SetOutputToSizedImage(InferenceContext* c, DimensionHandle batch_dim,
- int size_input_idx, DimensionHandle channel_dim) {
- // Verify shape of size input.
- ShapeHandle size;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(size_input_idx), 1, &size));
- DimensionHandle unused;
- TF_RETURN_IF_ERROR(c->WithValue(c->Dim(size, 0), 2, &unused));
-
- // Get size values from the size tensor.
- const Tensor* size_tensor = c->input_tensor(size_input_idx);
- DimensionHandle width;
- DimensionHandle height;
- if (size_tensor == nullptr) {
- width = c->UnknownDim();
- height = c->UnknownDim();
- } else {
- // TODO(petewarden) - Remove once we have constant evaluation in C++ only.
- if (size_tensor->dtype() != DT_INT32) {
- return errors::InvalidArgument(
- "Bad size input type for SetOutputToSizedImage: Expected DT_INT32 "
- "but got ",
- DataTypeString(size_tensor->dtype()), " for input #", size_input_idx,
- " in ", c->DebugString());
- }
- auto vec = size_tensor->vec<int32>();
- height = c->MakeDim(vec(0));
- width = c->MakeDim(vec(1));
- }
- c->set_output(0, c->MakeShape({batch_dim, height, width, channel_dim}));
- return Status::OK();
-}
-
-Status ResizeShapeFn(InferenceContext* c) {
- ShapeHandle input;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
- return SetOutputToSizedImage(c, c->Dim(input, 0), 2 /* size_input_idx */,
- c->Dim(input, 3));
-}
-
-} // namespace
-
// TODO(ringwalt): Add a "fill_mode" argument with "constant", "mirror", etc.
// TODO(ringwalt): Add a "fill_constant" argument for constant mode (default 0).
// TODO(ringwalt): Add an "output_shape" argument. This is sufficient to
@@ -75,11 +29,13 @@ Status ResizeShapeFn(InferenceContext* c) {
REGISTER_OP("ImageProjectiveTransform")
.Input("images: dtype")
.Input("transforms: float32")
- .Input("output_shape: int32")
.Attr("dtype: {uint8, int32, int64, float32, float64}")
.Attr("interpolation: string")
.Output("transformed_images: dtype")
- .SetShapeFn(ResizeShapeFn)
+ .SetShapeFn([](InferenceContext* c) {
+ c->set_output(0, c->input(0));
+ return Status::OK();
+ })
.Doc(R"doc(
Applies the given transform to each of the images.
diff --git a/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py b/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py
index c0151d320f..b50177ae56 100644
--- a/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py
+++ b/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py
@@ -195,40 +195,10 @@ class ImageOpsTest(test_util.TensorFlowTestCase):
x_init_value=test_image)
self.assertLess(left_err, 1e-10)
- def _test_grad_different_shape(self, input_shape, output_shape):
- with self.test_session():
- test_image_shape = input_shape
- test_image = np.random.randn(*test_image_shape)
- test_image_tensor = constant_op.constant(
- test_image, shape=test_image_shape)
- test_transform = image_ops.angles_to_projective_transforms(
- np.pi / 2, 4, 4)
-
- if len(output_shape) == 2:
- resize_shape = output_shape
- elif len(output_shape) == 3:
- resize_shape = output_shape[0:2]
- elif len(output_shape) == 4:
- resize_shape = output_shape[1:3]
- output = image_ops.transform(
- images=test_image_tensor,
- transforms=test_transform,
- output_shape=resize_shape)
- left_err = gradient_checker.compute_gradient_error(
- test_image_tensor,
- test_image_shape,
- output,
- output_shape,
- x_init_value=test_image)
- self.assertLess(left_err, 1e-10)
-
def test_grad(self):
self._test_grad([16, 16])
self._test_grad([4, 12, 12])
self._test_grad([3, 4, 12, 12])
- self._test_grad_different_shape([16, 16], [8, 8])
- self._test_grad_different_shape([4, 12, 3], [8, 24, 3])
- self._test_grad_different_shape([3, 4, 12, 3], [3, 8, 24, 3])
class BipartiteMatchTest(test_util.TensorFlowTestCase):
diff --git a/tensorflow/contrib/image/python/ops/image_ops.py b/tensorflow/contrib/image/python/ops/image_ops.py
index 0cb7bdc75d..c139ae89d8 100644
--- a/tensorflow/contrib/image/python/ops/image_ops.py
+++ b/tensorflow/contrib/image/python/ops/image_ops.py
@@ -212,11 +212,7 @@ def translations_to_projective_transforms(translations, name=None):
axis=1)
-def transform(images,
- transforms,
- output_shape=None,
- interpolation="NEAREST",
- name=None):
+def transform(images, transforms, interpolation="NEAREST", name=None):
"""Applies the given transform(s) to the image(s).
Args:
@@ -232,10 +228,7 @@ def transform(images,
where `k = c0 x + c1 y + 1`. The transforms are *inverted* compared to
the transform mapping input points to output points. Note that gradients
are not backpropagated into transformation parameters.
- output_shape: Output dimesion after the transform, [height, width].
- If None, output is the same size as input image.
interpolation: Interpolation mode. Supported values: "NEAREST", "BILINEAR".
- name: The name of the op.
Returns:
Image(s) with the same type and shape as `images`, with the given
@@ -262,14 +255,6 @@ def transform(images,
else:
raise TypeError("Images should have rank between 2 and 4.")
- if output_shape is None:
- output_shape = images.get_shape()[1:3]
- elif len(output_shape) != 2:
- raise TypeError(
- "output_shape must either be None or a vector of 2 elements.")
- output_shape = ops.convert_to_tensor(
- output_shape, name="output_shape", dtype=dtypes.int32)
-
if len(transform_or_transforms.get_shape()) == 1:
transforms = transform_or_transforms[None]
elif transform_or_transforms.get_shape().ndims is None:
@@ -280,7 +265,7 @@ def transform(images,
else:
raise TypeError("Transforms should have rank 1 or 2.")
output = gen_image_ops.image_projective_transform(
- images, transforms, output_shape, interpolation=interpolation.upper())
+ images, transforms, interpolation=interpolation.upper())
if len(image_or_images.get_shape()) == 2:
return output[0, :, :, 0]
elif len(image_or_images.get_shape()) == 3:
@@ -390,6 +375,14 @@ def _image_projective_transform_grad(op, grad):
if image_or_images.dtype.base_dtype not in _IMAGE_DTYPES:
raise TypeError("Invalid dtype %s." % image_or_images.dtype)
+ if len(image_or_images.get_shape()) == 2:
+ images = image_or_images[None, :, :, None]
+ elif len(image_or_images.get_shape()) == 3:
+ images = image_or_images[None, :, :, :]
+ elif len(image_or_images.get_shape()) == 4:
+ images = image_or_images
+ else:
+ raise TypeError("Images should have rank between 2 and 4")
if len(transform_or_transforms.get_shape()) == 1:
transforms = transform_or_transforms[None]
elif len(transform_or_transforms.get_shape()) == 2:
@@ -402,11 +395,13 @@ def _image_projective_transform_grad(op, grad):
inverse = linalg_ops.matrix_inverse(transforms)
transforms = matrices_to_flat_transforms(inverse)
output = gen_image_ops.image_projective_transform(
- images=grad,
- transforms=transforms,
- output_shape=image_or_images.get_shape()[1:3],
- interpolation=interpolation)
- return [output, None, None]
+ grad, transforms, interpolation=interpolation)
+ if len(image_or_images.get_shape()) == 2:
+ return [output[0, :, :, 0], None]
+ elif len(image_or_images.get_shape()) == 3:
+ return [output[0, :, :, :], None]
+ else:
+ return [output, None]
def bipartite_match(distance_mat,