diff options
author | 2017-05-05 14:28:59 -0800 | |
---|---|---|
committer | 2017-05-05 15:54:38 -0700 | |
commit | 37e3b71b49495af3873e7916a2ff28e598931b89 (patch) | |
tree | e6cc1933dd15962aaf489d3927cd3a72acf8947c /tensorflow/contrib/image | |
parent | fe16da297ceb2cd71e1c7128116e0339ec6789ed (diff) |
Add bilinear interpolation to tf.contrib.image.
Change: 155247916
Diffstat (limited to 'tensorflow/contrib/image')
-rw-r--r-- | tensorflow/contrib/image/kernels/image_ops.cc | 24 | ||||
-rw-r--r-- | tensorflow/contrib/image/kernels/image_ops.h | 95 | ||||
-rw-r--r-- | tensorflow/contrib/image/ops/image_ops.cc | 2 | ||||
-rw-r--r-- | tensorflow/contrib/image/python/kernel_tests/image_ops_test.py | 49 | ||||
-rw-r--r-- | tensorflow/contrib/image/python/ops/image_ops.py | 17 |
5 files changed, 155 insertions, 32 deletions
diff --git a/tensorflow/contrib/image/kernels/image_ops.cc b/tensorflow/contrib/image/kernels/image_ops.cc index 8d50541771..8a97f07732 100644 --- a/tensorflow/contrib/image/kernels/image_ops.cc +++ b/tensorflow/contrib/image/kernels/image_ops.cc @@ -43,13 +43,29 @@ template class FillProjectiveTransform<CPUDevice, double>; typedef Eigen::ThreadPoolDevice CPUDevice; using functor::FillProjectiveTransform; +using generator::INTERPOLATION_BILINEAR; +using generator::INTERPOLATION_NEAREST; +using generator::Interpolation; using generator::ProjectiveGenerator; template <typename Device, typename T> class ImageProjectiveTransform : public OpKernel { + private: + Interpolation interpolation_; + public: - explicit ImageProjectiveTransform(OpKernelConstruction* ctx) - : OpKernel(ctx) {} + explicit ImageProjectiveTransform(OpKernelConstruction* ctx) : OpKernel(ctx) { + string interpolation_str; + OP_REQUIRES_OK(ctx, ctx->GetAttr("interpolation", &interpolation_str)); + if (interpolation_str == "NEAREST") { + interpolation_ = INTERPOLATION_NEAREST; + } else if (interpolation_str == "BILINEAR") { + interpolation_ = INTERPOLATION_BILINEAR; + } else { + LOG(FATAL) << "Invalid interpolation " << interpolation_str + << ". Supported types: NEAREST, BILINEAR"; + } + } void Compute(OpKernelContext* ctx) override { const Tensor& images_t = ctx->input(0); @@ -68,8 +84,8 @@ class ImageProjectiveTransform : public OpKernel { Tensor* output_t; OP_REQUIRES_OK(ctx, ctx->allocate_output(0, images_t.shape(), &output_t)); auto output = output_t->tensor<T, 4>(); - const FillProjectiveTransform<Device, T> functor; - functor(ctx->eigen_device<Device>(), &output, images, transform); + (FillProjectiveTransform<Device, T>(interpolation_))( + ctx->eigen_device<Device>(), &output, images, transform); } }; diff --git a/tensorflow/contrib/image/kernels/image_ops.h b/tensorflow/contrib/image/kernels/image_ops.h index 92b908a1c6..692e33fcf3 100644 --- a/tensorflow/contrib/image/kernels/image_ops.h +++ b/tensorflow/contrib/image/kernels/image_ops.h @@ -28,6 +28,8 @@ namespace tensorflow { namespace generator { +enum Interpolation { INTERPOLATION_NEAREST, INTERPOLATION_BILINEAR }; + using Eigen::array; using Eigen::DenseIndex; @@ -36,20 +38,19 @@ class ProjectiveGenerator { private: typename TTypes<T, 4>::ConstTensor input_; typename TTypes<float>::ConstMatrix transforms_; + const Interpolation interpolation_; public: static const int kNumParameters = 8; EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE ProjectiveGenerator(typename TTypes<T, 4>::ConstTensor input, - typename TTypes<float>::ConstMatrix transforms) - : input_(input), transforms_(transforms) {} + typename TTypes<float>::ConstMatrix transforms, + const Interpolation interpolation) + : input_(input), transforms_(transforms), interpolation_(interpolation) {} EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T operator()(const array<DenseIndex, 4>& coords) const { - array<DenseIndex, 4> input_coords; - input_coords[0] = coords[0]; - const int64 output_y = coords[1]; const int64 output_x = coords[2]; const float* transform = @@ -57,24 +58,73 @@ class ProjectiveGenerator { ? transforms_.data() : &transforms_.data()[transforms_.dimension(1) * coords[0]]; float projection = transform[6] * output_x + transform[7] * output_y + 1.f; - const int64 input_x = std::round( + const float input_x = (transform[0] * output_x + transform[1] * output_y + transform[2]) / - projection); - const int64 input_y = std::round( + projection; + const float input_y = (transform[3] * output_x + transform[4] * output_y + transform[5]) / - projection); - - if (!(0 <= input_y && input_y < input_.dimension(1) && 0 <= input_x && - input_x < input_.dimension(2))) { - // TODO(ringwalt): Add a fill value input. - return T(0); + projection; + + // TODO(ringwalt): Add a fill value input. + static const T fill_value = T(0); + switch (interpolation_) { + case INTERPOLATION_NEAREST: + // Switch the order of x and y again for indexing into the image. + return nearest_interpolation(coords[0], input_y, input_x, coords[3], + fill_value); + case INTERPOLATION_BILINEAR: + return bilinear_interpolation(coords[0], input_y, input_x, coords[3], + fill_value); } - input_coords[1] = input_y; - input_coords[2] = input_x; + } + + private: + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T + nearest_interpolation(const DenseIndex batch, const float y, const float x, + const DenseIndex channel, const T fill_value) const { + return read_with_fill_value(batch, DenseIndex(std::round(y)), + DenseIndex(std::round(x)), channel, fill_value); + } - input_coords[3] = coords[3]; + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T + bilinear_interpolation(const DenseIndex batch, const float y, const float x, + const DenseIndex channel, const T fill_value) const { + const float y_floor = std::floor(y); + const float x_floor = std::floor(x); + const float y_ceil = y_floor + 1; + const float x_ceil = x_floor + 1; + // f(x, y_floor) = (x_ceil - x) / (x_ceil - x_floor) * f(x_floor, y_floor) + // + (x - x_floor) / (x_ceil - x_floor) * f(x_ceil, y_floor) + const float value_yfloor = + (x_ceil - x) * read_with_fill_value(batch, DenseIndex(y_floor), + DenseIndex(x_floor), channel, + fill_value) + + (x - x_floor) * read_with_fill_value(batch, DenseIndex(y_floor), + DenseIndex(x_ceil), channel, + fill_value); + // f(x, y_ceil) = (x_ceil - x) / (x_ceil - x_floor) * f(x_floor, y_ceil) + // + (x - x_floor) / (x_ceil - x_floor) * f(x_ceil, y_ceil) + const float value_yceil = + (x_ceil - x) * read_with_fill_value(batch, DenseIndex(y_ceil), + DenseIndex(x_floor), channel, + fill_value) + + (x - x_floor) * read_with_fill_value(batch, DenseIndex(y_ceil), + DenseIndex(x_ceil), channel, + fill_value); + // f(x, y) = (y_ceil - y) / (y_ceil - y_floor) * f(x, y_floor) + // + (y - y_floor) / (y_ceil - y_floor) * f(x, y_ceil) + return T((y_ceil - y) * value_yfloor + (y - y_floor) * value_yceil); + } - return input_(input_coords); + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T read_with_fill_value( + const DenseIndex batch, const DenseIndex y, const DenseIndex x, + const DenseIndex channel, const T fill_value) const { + // batch and channel must be correct, because they are passed unchanged from + // the input. + return (0 <= y && y < input_.dimension(1) && 0 <= x && + x < input_.dimension(2)) + ? input_(array<DenseIndex, 4>{batch, y, x, channel}) + : fill_value; } }; @@ -85,6 +135,7 @@ class ProjectiveGenerator { // some Eigen device code. namespace functor { +using generator::Interpolation; using generator::ProjectiveGenerator; template <typename Device, typename T> @@ -92,15 +143,17 @@ struct FillProjectiveTransform { typedef typename TTypes<T, 4>::Tensor OutputType; typedef typename TTypes<T, 4>::ConstTensor InputType; typedef typename TTypes<float, 2>::ConstTensor TransformsType; + const Interpolation interpolation_; - FillProjectiveTransform() {} + FillProjectiveTransform(Interpolation interpolation) + : interpolation_(interpolation) {} EIGEN_ALWAYS_INLINE void operator()(const Device& device, OutputType* output, const InputType& images, const TransformsType& transform) const { - ProjectiveGenerator<Device, T> generator(images, transform); - output->device(device) = images.generate(generator); + output->device(device) = images.generate( + ProjectiveGenerator<Device, T>(images, transform, interpolation_)); } }; diff --git a/tensorflow/contrib/image/ops/image_ops.cc b/tensorflow/contrib/image/ops/image_ops.cc index 18c16cf1bb..a6d3fa4b64 100644 --- a/tensorflow/contrib/image/ops/image_ops.cc +++ b/tensorflow/contrib/image/ops/image_ops.cc @@ -23,13 +23,13 @@ using shape_inference::InferenceContext; // TODO(ringwalt): Add a "fill_mode" argument with "constant", "mirror", etc. // TODO(ringwalt): Add a "fill_constant" argument for constant mode (default 0). -// TODO(ringwalt): Add an "interpolation" argument with "none", "bilinear", etc. // TODO(ringwalt): Add an "output_shape" argument. This is sufficient to // implement "same" and "valid" modes in the Python function. REGISTER_OP("ImageProjectiveTransform") .Input("images: dtype") .Input("transforms: float32") .Attr("dtype: {uint8, int32, int64, float32, float64}") + .Attr("interpolation: string") .Output("transformed_images: dtype") .SetShapeFn([](InferenceContext* c) { c->set_output(0, c->input(0)); diff --git a/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py b/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py index 33bd30b4e8..5e78b590df 100644 --- a/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py +++ b/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py @@ -111,6 +111,55 @@ class ImageOpsTest(test_util.TensorFlowTestCase): [0, 1, 0, 1], [0, 1, 1, 1]]) + def test_bilinear(self): + with self.test_session(): + image = constant_op.constant( + [[0, 0, 0, 0, 0], + [0, 1, 1, 1, 0], + [0, 1, 0, 1, 0], + [0, 1, 1, 1, 0], + [0, 0, 0, 0, 0]], + dtypes.float32) + # The following result matches: + # >>> scipy.ndimage.rotate(image, 45, order=1, reshape=False) + # which uses spline interpolation of order 1, equivalent to bilinear + # interpolation. + self.assertAllClose( + image_ops.rotate(image, np.pi / 4.0, interpolation="BILINEAR").eval(), + [[0.000, 0.000, 0.343, 0.000, 0.000], + [0.000, 0.586, 0.914, 0.586, 0.000], + [0.343, 0.914, 0.000, 0.914, 0.343], + [0.000, 0.586, 0.914, 0.586, 0.000], + [0.000, 0.000, 0.343, 0.000, 0.000]], + atol=0.001) + self.assertAllClose( + image_ops.rotate(image, np.pi / 4.0, interpolation="NEAREST").eval(), + [[0, 0, 1, 0, 0], + [0, 1, 1, 1, 0], + [1, 1, 0, 1, 1], + [0, 1, 1, 1, 0], + [0, 0, 1, 0, 0]]) + + def test_bilinear_uint8(self): + with self.test_session(): + image = constant_op.constant( + np.asarray( + [[0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 255, 255, 255, 0.0], + [0.0, 255, 0.0, 255, 0.0], + [0.0, 255, 255, 255, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0]], + np.uint8), + dtypes.uint8) + # == np.rint((expected image above) * 255) + self.assertAllEqual( + image_ops.rotate(image, np.pi / 4.0, interpolation="BILINEAR").eval(), + [[0.0, 0.0, 87., 0.0, 0.0], + [0.0, 149, 233, 149, 0.0], + [87., 233, 0.0, 233, 87.], + [0.0, 149, 233, 149, 0.0], + [0.0, 0.0, 87., 0.0, 0.0]]) + def _test_grad(self, shape_to_test): with self.test_session(): test_image_shape = shape_to_test diff --git a/tensorflow/contrib/image/python/ops/image_ops.py b/tensorflow/contrib/image/python/ops/image_ops.py index 9efdb0d521..0d51d0dee1 100644 --- a/tensorflow/contrib/image/python/ops/image_ops.py +++ b/tensorflow/contrib/image/python/ops/image_ops.py @@ -37,7 +37,7 @@ _IMAGE_DTYPES = set( ops.RegisterShape("ImageProjectiveTransform")(common_shapes.call_cpp_shape_fn) -def rotate(images, angles): +def rotate(images, angles, interpolation="NEAREST"): """Rotate image(s) by the passed angle(s) in radians. Args: @@ -46,6 +46,7 @@ def rotate(images, angles): (num_rows, num_columns) (HW). angles: A scalar angle to rotate all images by, or (if images has rank 4) a vector of length num_images, with an angle for each image in the batch. + interpolation: Interpolation mode. Supported values: "NEAREST", "BILINEAR". Returns: Image(s) with the same type and shape as `images`, rotated by the given @@ -70,7 +71,8 @@ def rotate(images, angles): image_width = math_ops.cast(array_ops.shape(images)[2], dtypes.float32)[None] output = transform( images, - angles_to_projective_transforms(angles, image_width, image_height)) + angles_to_projective_transforms(angles, image_height, image_width), + interpolation=interpolation) if len(image_or_images.get_shape()) == 2: return output[0, :, :, 0] elif len(image_or_images.get_shape()) == 3: @@ -120,7 +122,7 @@ def angles_to_projective_transforms(angles, image_height, image_width): axis=1) -def transform(images, transforms): +def transform(images, transforms, interpolation="NEAREST"): """Applies the given transform(s) to the image(s). Args: @@ -134,6 +136,7 @@ def transform(images, transforms): `(x', y') = ((a0 x + a1 y + a2) / k, (b0 x + b1 y + b2) / k)`, where `k = c0 x + c1 y + 1`. The transforms are *inverted* compared to the transform mapping input points to output points. + interpolation: Interpolation mode. Supported values: "NEAREST", "BILINEAR". Returns: Image(s) with the same type and shape as `images`, with the given @@ -163,8 +166,8 @@ def transform(images, transforms): transforms = transform_or_transforms else: raise TypeError("Transforms should have rank 1 or 2.") - # pylint: disable=protected-access - output = gen_image_ops.image_projective_transform(images, transforms) + output = gen_image_ops.image_projective_transform( + images, transforms, interpolation=interpolation.upper()) if len(image_or_images.get_shape()) == 2: return output[0, :, :, 0] elif len(image_or_images.get_shape()) == 3: @@ -220,6 +223,7 @@ def _image_projective_transform_grad(op, grad): """Computes the gradient for ImageProjectiveTransform.""" images = op.inputs[0] transforms = op.inputs[1] + interpolation = op.get_attr("interpolation") image_or_images = ops.convert_to_tensor(images, name="images") transform_or_transforms = ops.convert_to_tensor( @@ -246,7 +250,8 @@ def _image_projective_transform_grad(op, grad): transforms = _flat_transforms_to_matrices(transforms=transforms) inverse = linalg_ops.matrix_inverse(transforms) transforms = _transform_matrices_to_flat(inverse) - output = gen_image_ops.image_projective_transform(grad, transforms) + output = gen_image_ops.image_projective_transform( + grad, transforms, interpolation=interpolation) if len(image_or_images.get_shape()) == 2: return [output[0, :, :, 0], None] elif len(image_or_images.get_shape()) == 3: |