aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/image
diff options
context:
space:
mode:
authorGravatar Dan Ringwalt <ringwalt@google.com>2017-05-05 14:28:59 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-05-05 15:54:38 -0700
commit37e3b71b49495af3873e7916a2ff28e598931b89 (patch)
treee6cc1933dd15962aaf489d3927cd3a72acf8947c /tensorflow/contrib/image
parentfe16da297ceb2cd71e1c7128116e0339ec6789ed (diff)
Add bilinear interpolation to tf.contrib.image.
Change: 155247916
Diffstat (limited to 'tensorflow/contrib/image')
-rw-r--r--tensorflow/contrib/image/kernels/image_ops.cc24
-rw-r--r--tensorflow/contrib/image/kernels/image_ops.h95
-rw-r--r--tensorflow/contrib/image/ops/image_ops.cc2
-rw-r--r--tensorflow/contrib/image/python/kernel_tests/image_ops_test.py49
-rw-r--r--tensorflow/contrib/image/python/ops/image_ops.py17
5 files changed, 155 insertions, 32 deletions
diff --git a/tensorflow/contrib/image/kernels/image_ops.cc b/tensorflow/contrib/image/kernels/image_ops.cc
index 8d50541771..8a97f07732 100644
--- a/tensorflow/contrib/image/kernels/image_ops.cc
+++ b/tensorflow/contrib/image/kernels/image_ops.cc
@@ -43,13 +43,29 @@ template class FillProjectiveTransform<CPUDevice, double>;
typedef Eigen::ThreadPoolDevice CPUDevice;
using functor::FillProjectiveTransform;
+using generator::INTERPOLATION_BILINEAR;
+using generator::INTERPOLATION_NEAREST;
+using generator::Interpolation;
using generator::ProjectiveGenerator;
template <typename Device, typename T>
class ImageProjectiveTransform : public OpKernel {
+ private:
+ Interpolation interpolation_;
+
public:
- explicit ImageProjectiveTransform(OpKernelConstruction* ctx)
- : OpKernel(ctx) {}
+ explicit ImageProjectiveTransform(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ string interpolation_str;
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("interpolation", &interpolation_str));
+ if (interpolation_str == "NEAREST") {
+ interpolation_ = INTERPOLATION_NEAREST;
+ } else if (interpolation_str == "BILINEAR") {
+ interpolation_ = INTERPOLATION_BILINEAR;
+ } else {
+ LOG(FATAL) << "Invalid interpolation " << interpolation_str
+ << ". Supported types: NEAREST, BILINEAR";
+ }
+ }
void Compute(OpKernelContext* ctx) override {
const Tensor& images_t = ctx->input(0);
@@ -68,8 +84,8 @@ class ImageProjectiveTransform : public OpKernel {
Tensor* output_t;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, images_t.shape(), &output_t));
auto output = output_t->tensor<T, 4>();
- const FillProjectiveTransform<Device, T> functor;
- functor(ctx->eigen_device<Device>(), &output, images, transform);
+ (FillProjectiveTransform<Device, T>(interpolation_))(
+ ctx->eigen_device<Device>(), &output, images, transform);
}
};
diff --git a/tensorflow/contrib/image/kernels/image_ops.h b/tensorflow/contrib/image/kernels/image_ops.h
index 92b908a1c6..692e33fcf3 100644
--- a/tensorflow/contrib/image/kernels/image_ops.h
+++ b/tensorflow/contrib/image/kernels/image_ops.h
@@ -28,6 +28,8 @@ namespace tensorflow {
namespace generator {
+enum Interpolation { INTERPOLATION_NEAREST, INTERPOLATION_BILINEAR };
+
using Eigen::array;
using Eigen::DenseIndex;
@@ -36,20 +38,19 @@ class ProjectiveGenerator {
private:
typename TTypes<T, 4>::ConstTensor input_;
typename TTypes<float>::ConstMatrix transforms_;
+ const Interpolation interpolation_;
public:
static const int kNumParameters = 8;
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
ProjectiveGenerator(typename TTypes<T, 4>::ConstTensor input,
- typename TTypes<float>::ConstMatrix transforms)
- : input_(input), transforms_(transforms) {}
+ typename TTypes<float>::ConstMatrix transforms,
+ const Interpolation interpolation)
+ : input_(input), transforms_(transforms), interpolation_(interpolation) {}
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T
operator()(const array<DenseIndex, 4>& coords) const {
- array<DenseIndex, 4> input_coords;
- input_coords[0] = coords[0];
-
const int64 output_y = coords[1];
const int64 output_x = coords[2];
const float* transform =
@@ -57,24 +58,73 @@ class ProjectiveGenerator {
? transforms_.data()
: &transforms_.data()[transforms_.dimension(1) * coords[0]];
float projection = transform[6] * output_x + transform[7] * output_y + 1.f;
- const int64 input_x = std::round(
+ const float input_x =
(transform[0] * output_x + transform[1] * output_y + transform[2]) /
- projection);
- const int64 input_y = std::round(
+ projection;
+ const float input_y =
(transform[3] * output_x + transform[4] * output_y + transform[5]) /
- projection);
-
- if (!(0 <= input_y && input_y < input_.dimension(1) && 0 <= input_x &&
- input_x < input_.dimension(2))) {
- // TODO(ringwalt): Add a fill value input.
- return T(0);
+ projection;
+
+ // TODO(ringwalt): Add a fill value input.
+ static const T fill_value = T(0);
+ switch (interpolation_) {
+ case INTERPOLATION_NEAREST:
+ // Switch the order of x and y again for indexing into the image.
+ return nearest_interpolation(coords[0], input_y, input_x, coords[3],
+ fill_value);
+ case INTERPOLATION_BILINEAR:
+ return bilinear_interpolation(coords[0], input_y, input_x, coords[3],
+ fill_value);
}
- input_coords[1] = input_y;
- input_coords[2] = input_x;
+ }
+
+ private:
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T
+ nearest_interpolation(const DenseIndex batch, const float y, const float x,
+ const DenseIndex channel, const T fill_value) const {
+ return read_with_fill_value(batch, DenseIndex(std::round(y)),
+ DenseIndex(std::round(x)), channel, fill_value);
+ }
- input_coords[3] = coords[3];
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T
+ bilinear_interpolation(const DenseIndex batch, const float y, const float x,
+ const DenseIndex channel, const T fill_value) const {
+ const float y_floor = std::floor(y);
+ const float x_floor = std::floor(x);
+ const float y_ceil = y_floor + 1;
+ const float x_ceil = x_floor + 1;
+ // f(x, y_floor) = (x_ceil - x) / (x_ceil - x_floor) * f(x_floor, y_floor)
+ // + (x - x_floor) / (x_ceil - x_floor) * f(x_ceil, y_floor)
+ const float value_yfloor =
+ (x_ceil - x) * read_with_fill_value(batch, DenseIndex(y_floor),
+ DenseIndex(x_floor), channel,
+ fill_value) +
+ (x - x_floor) * read_with_fill_value(batch, DenseIndex(y_floor),
+ DenseIndex(x_ceil), channel,
+ fill_value);
+ // f(x, y_ceil) = (x_ceil - x) / (x_ceil - x_floor) * f(x_floor, y_ceil)
+ // + (x - x_floor) / (x_ceil - x_floor) * f(x_ceil, y_ceil)
+ const float value_yceil =
+ (x_ceil - x) * read_with_fill_value(batch, DenseIndex(y_ceil),
+ DenseIndex(x_floor), channel,
+ fill_value) +
+ (x - x_floor) * read_with_fill_value(batch, DenseIndex(y_ceil),
+ DenseIndex(x_ceil), channel,
+ fill_value);
+ // f(x, y) = (y_ceil - y) / (y_ceil - y_floor) * f(x, y_floor)
+ // + (y - y_floor) / (y_ceil - y_floor) * f(x, y_ceil)
+ return T((y_ceil - y) * value_yfloor + (y - y_floor) * value_yceil);
+ }
- return input_(input_coords);
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T read_with_fill_value(
+ const DenseIndex batch, const DenseIndex y, const DenseIndex x,
+ const DenseIndex channel, const T fill_value) const {
+ // batch and channel must be correct, because they are passed unchanged from
+ // the input.
+ return (0 <= y && y < input_.dimension(1) && 0 <= x &&
+ x < input_.dimension(2))
+ ? input_(array<DenseIndex, 4>{batch, y, x, channel})
+ : fill_value;
}
};
@@ -85,6 +135,7 @@ class ProjectiveGenerator {
// some Eigen device code.
namespace functor {
+using generator::Interpolation;
using generator::ProjectiveGenerator;
template <typename Device, typename T>
@@ -92,15 +143,17 @@ struct FillProjectiveTransform {
typedef typename TTypes<T, 4>::Tensor OutputType;
typedef typename TTypes<T, 4>::ConstTensor InputType;
typedef typename TTypes<float, 2>::ConstTensor TransformsType;
+ const Interpolation interpolation_;
- FillProjectiveTransform() {}
+ FillProjectiveTransform(Interpolation interpolation)
+ : interpolation_(interpolation) {}
EIGEN_ALWAYS_INLINE
void operator()(const Device& device, OutputType* output,
const InputType& images,
const TransformsType& transform) const {
- ProjectiveGenerator<Device, T> generator(images, transform);
- output->device(device) = images.generate(generator);
+ output->device(device) = images.generate(
+ ProjectiveGenerator<Device, T>(images, transform, interpolation_));
}
};
diff --git a/tensorflow/contrib/image/ops/image_ops.cc b/tensorflow/contrib/image/ops/image_ops.cc
index 18c16cf1bb..a6d3fa4b64 100644
--- a/tensorflow/contrib/image/ops/image_ops.cc
+++ b/tensorflow/contrib/image/ops/image_ops.cc
@@ -23,13 +23,13 @@ using shape_inference::InferenceContext;
// TODO(ringwalt): Add a "fill_mode" argument with "constant", "mirror", etc.
// TODO(ringwalt): Add a "fill_constant" argument for constant mode (default 0).
-// TODO(ringwalt): Add an "interpolation" argument with "none", "bilinear", etc.
// TODO(ringwalt): Add an "output_shape" argument. This is sufficient to
// implement "same" and "valid" modes in the Python function.
REGISTER_OP("ImageProjectiveTransform")
.Input("images: dtype")
.Input("transforms: float32")
.Attr("dtype: {uint8, int32, int64, float32, float64}")
+ .Attr("interpolation: string")
.Output("transformed_images: dtype")
.SetShapeFn([](InferenceContext* c) {
c->set_output(0, c->input(0));
diff --git a/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py b/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py
index 33bd30b4e8..5e78b590df 100644
--- a/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py
+++ b/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py
@@ -111,6 +111,55 @@ class ImageOpsTest(test_util.TensorFlowTestCase):
[0, 1, 0, 1],
[0, 1, 1, 1]])
+ def test_bilinear(self):
+ with self.test_session():
+ image = constant_op.constant(
+ [[0, 0, 0, 0, 0],
+ [0, 1, 1, 1, 0],
+ [0, 1, 0, 1, 0],
+ [0, 1, 1, 1, 0],
+ [0, 0, 0, 0, 0]],
+ dtypes.float32)
+ # The following result matches:
+ # >>> scipy.ndimage.rotate(image, 45, order=1, reshape=False)
+ # which uses spline interpolation of order 1, equivalent to bilinear
+ # interpolation.
+ self.assertAllClose(
+ image_ops.rotate(image, np.pi / 4.0, interpolation="BILINEAR").eval(),
+ [[0.000, 0.000, 0.343, 0.000, 0.000],
+ [0.000, 0.586, 0.914, 0.586, 0.000],
+ [0.343, 0.914, 0.000, 0.914, 0.343],
+ [0.000, 0.586, 0.914, 0.586, 0.000],
+ [0.000, 0.000, 0.343, 0.000, 0.000]],
+ atol=0.001)
+ self.assertAllClose(
+ image_ops.rotate(image, np.pi / 4.0, interpolation="NEAREST").eval(),
+ [[0, 0, 1, 0, 0],
+ [0, 1, 1, 1, 0],
+ [1, 1, 0, 1, 1],
+ [0, 1, 1, 1, 0],
+ [0, 0, 1, 0, 0]])
+
+ def test_bilinear_uint8(self):
+ with self.test_session():
+ image = constant_op.constant(
+ np.asarray(
+ [[0.0, 0.0, 0.0, 0.0, 0.0],
+ [0.0, 255, 255, 255, 0.0],
+ [0.0, 255, 0.0, 255, 0.0],
+ [0.0, 255, 255, 255, 0.0],
+ [0.0, 0.0, 0.0, 0.0, 0.0]],
+ np.uint8),
+ dtypes.uint8)
+ # == np.rint((expected image above) * 255)
+ self.assertAllEqual(
+ image_ops.rotate(image, np.pi / 4.0, interpolation="BILINEAR").eval(),
+ [[0.0, 0.0, 87., 0.0, 0.0],
+ [0.0, 149, 233, 149, 0.0],
+ [87., 233, 0.0, 233, 87.],
+ [0.0, 149, 233, 149, 0.0],
+ [0.0, 0.0, 87., 0.0, 0.0]])
+
def _test_grad(self, shape_to_test):
with self.test_session():
test_image_shape = shape_to_test
diff --git a/tensorflow/contrib/image/python/ops/image_ops.py b/tensorflow/contrib/image/python/ops/image_ops.py
index 9efdb0d521..0d51d0dee1 100644
--- a/tensorflow/contrib/image/python/ops/image_ops.py
+++ b/tensorflow/contrib/image/python/ops/image_ops.py
@@ -37,7 +37,7 @@ _IMAGE_DTYPES = set(
ops.RegisterShape("ImageProjectiveTransform")(common_shapes.call_cpp_shape_fn)
-def rotate(images, angles):
+def rotate(images, angles, interpolation="NEAREST"):
"""Rotate image(s) by the passed angle(s) in radians.
Args:
@@ -46,6 +46,7 @@ def rotate(images, angles):
(num_rows, num_columns) (HW).
angles: A scalar angle to rotate all images by, or (if images has rank 4)
a vector of length num_images, with an angle for each image in the batch.
+ interpolation: Interpolation mode. Supported values: "NEAREST", "BILINEAR".
Returns:
Image(s) with the same type and shape as `images`, rotated by the given
@@ -70,7 +71,8 @@ def rotate(images, angles):
image_width = math_ops.cast(array_ops.shape(images)[2], dtypes.float32)[None]
output = transform(
images,
- angles_to_projective_transforms(angles, image_width, image_height))
+ angles_to_projective_transforms(angles, image_height, image_width),
+ interpolation=interpolation)
if len(image_or_images.get_shape()) == 2:
return output[0, :, :, 0]
elif len(image_or_images.get_shape()) == 3:
@@ -120,7 +122,7 @@ def angles_to_projective_transforms(angles, image_height, image_width):
axis=1)
-def transform(images, transforms):
+def transform(images, transforms, interpolation="NEAREST"):
"""Applies the given transform(s) to the image(s).
Args:
@@ -134,6 +136,7 @@ def transform(images, transforms):
`(x', y') = ((a0 x + a1 y + a2) / k, (b0 x + b1 y + b2) / k)`,
where `k = c0 x + c1 y + 1`. The transforms are *inverted* compared to
the transform mapping input points to output points.
+ interpolation: Interpolation mode. Supported values: "NEAREST", "BILINEAR".
Returns:
Image(s) with the same type and shape as `images`, with the given
@@ -163,8 +166,8 @@ def transform(images, transforms):
transforms = transform_or_transforms
else:
raise TypeError("Transforms should have rank 1 or 2.")
- # pylint: disable=protected-access
- output = gen_image_ops.image_projective_transform(images, transforms)
+ output = gen_image_ops.image_projective_transform(
+ images, transforms, interpolation=interpolation.upper())
if len(image_or_images.get_shape()) == 2:
return output[0, :, :, 0]
elif len(image_or_images.get_shape()) == 3:
@@ -220,6 +223,7 @@ def _image_projective_transform_grad(op, grad):
"""Computes the gradient for ImageProjectiveTransform."""
images = op.inputs[0]
transforms = op.inputs[1]
+ interpolation = op.get_attr("interpolation")
image_or_images = ops.convert_to_tensor(images, name="images")
transform_or_transforms = ops.convert_to_tensor(
@@ -246,7 +250,8 @@ def _image_projective_transform_grad(op, grad):
transforms = _flat_transforms_to_matrices(transforms=transforms)
inverse = linalg_ops.matrix_inverse(transforms)
transforms = _transform_matrices_to_flat(inverse)
- output = gen_image_ops.image_projective_transform(grad, transforms)
+ output = gen_image_ops.image_projective_transform(
+ grad, transforms, interpolation=interpolation)
if len(image_or_images.get_shape()) == 2:
return [output[0, :, :, 0], None]
elif len(image_or_images.get_shape()) == 3: