aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/image
diff options
context:
space:
mode:
authorGravatar Dan Ringwalt <ringwalt@google.com>2018-01-17 11:33:12 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-17 11:39:36 -0800
commita41ab15aeea526355d807fcf35e057ece0e35bc4 (patch)
treefa58e8c5ff312bbe2181129d0633c1e91b767d38 /tensorflow/contrib/image
parent9104700715ac5dabd92c277693ee6bc8cd46bdd9 (diff)
Add tf.contrib.image.connected_components.
Comparable to scipy.ndimage.measurements.label. PiperOrigin-RevId: 182244926
Diffstat (limited to 'tensorflow/contrib/image')
-rwxr-xr-xtensorflow/contrib/image/BUILD21
-rwxr-xr-xtensorflow/contrib/image/__init__.py14
-rw-r--r--tensorflow/contrib/image/kernels/segmentation_ops.cc139
-rw-r--r--tensorflow/contrib/image/kernels/segmentation_ops.h303
-rw-r--r--tensorflow/contrib/image/ops/image_ops.cc30
-rw-r--r--tensorflow/contrib/image/python/kernel_tests/segmentation_test.py189
-rw-r--r--tensorflow/contrib/image/python/ops/image_ops.py70
7 files changed, 765 insertions, 1 deletions
diff --git a/tensorflow/contrib/image/BUILD b/tensorflow/contrib/image/BUILD
index ce2b279e51..3ff02e085e 100755
--- a/tensorflow/contrib/image/BUILD
+++ b/tensorflow/contrib/image/BUILD
@@ -14,6 +14,7 @@ load(
"tf_gen_op_libs",
"tf_gen_op_wrapper_py",
"tf_kernel_library",
+ "tf_py_test",
)
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
@@ -24,6 +25,8 @@ tf_custom_op_library(
"kernels/bipartite_match_op.cc",
"kernels/image_ops.cc",
"kernels/image_ops.h",
+ "kernels/segmentation_ops.cc",
+ "kernels/segmentation_ops.h",
"ops/image_ops.cc",
],
gpu_srcs = [
@@ -38,6 +41,8 @@ tf_kernel_library(
"kernels/bipartite_match_op.cc",
"kernels/image_ops.cc",
"kernels/image_ops.h",
+ "kernels/segmentation_ops.cc",
+ "kernels/segmentation_ops.h",
],
gpu_srcs = [
"kernels/image_ops_gpu.cu.cc",
@@ -78,6 +83,7 @@ tf_custom_op_py_library(
"//tensorflow/python:array_ops",
"//tensorflow/python:common_shapes",
"//tensorflow/python:constant_op",
+ "//tensorflow/python:control_flow_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:linalg_ops",
"//tensorflow/python:math_ops",
@@ -188,6 +194,21 @@ cuda_py_test(
],
)
+# Python kernel test for the connected components op; the test source is
+# python/kernel_tests/segmentation_test.py.
+tf_py_test(
+ name = "segmentation_test",
+ size = "medium",
+ srcs = ["python/kernel_tests/segmentation_test.py"],
+ additional_deps = [
+ ":distort_image_py",
+ ":image_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
tf_custom_op_library(
name = "python/ops/_single_image_random_dot_stereograms.so",
srcs = [
diff --git a/tensorflow/contrib/image/__init__.py b/tensorflow/contrib/image/__init__.py
index d030dffade..cc8ed117ba 100755
--- a/tensorflow/contrib/image/__init__.py
+++ b/tensorflow/contrib/image/__init__.py
@@ -20,6 +20,8 @@ This module provides functions for image manipulation; currently, chrominance
transforms (including changing saturation and hue) in YIQ space and
projective transforms (including rotation) are supported.
+## Image Transformation `Ops`
+
@@angles_to_projective_transforms
@@compose_transforms
@@adjust_yiq_hsv
@@ -28,19 +30,29 @@ projective transforms (including rotation) are supported.
@@transform
@@translate
@@translations_to_projective_transforms
+
+## Image Segmentation `Ops`
+
+@@connected_components
+
+## Matching `Ops`
+
@@bipartite_match
+
+## Random Dot Stereogram `Ops`
+
@@single_image_random_dot_stereograms
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-# pylint: disable=line-too-long
from tensorflow.contrib.image.python.ops.distort_image_ops import adjust_hsv_in_yiq
from tensorflow.contrib.image.python.ops.distort_image_ops import random_hsv_in_yiq
from tensorflow.contrib.image.python.ops.image_ops import angles_to_projective_transforms
from tensorflow.contrib.image.python.ops.image_ops import compose_transforms
+from tensorflow.contrib.image.python.ops.image_ops import connected_components
from tensorflow.contrib.image.python.ops.image_ops import rotate
from tensorflow.contrib.image.python.ops.image_ops import transform
from tensorflow.contrib.image.python.ops.image_ops import translate
diff --git a/tensorflow/contrib/image/kernels/segmentation_ops.cc b/tensorflow/contrib/image/kernels/segmentation_ops.cc
new file mode 100644
index 0000000000..fe8bf6e21c
--- /dev/null
+++ b/tensorflow/contrib/image/kernels/segmentation_ops.cc
@@ -0,0 +1,139 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// See docs for ImageConnectedComponents in ../ops/image_ops.cc, and description
+// of the algorithm in segmentation_ops.h.
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/contrib/image/kernels/segmentation_ops.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+
+using tensorflow::functor::BlockedImageUnionFindFunctor;
+using tensorflow::functor::FindRootFunctor;
+using tensorflow::functor::ImageConnectedComponentsFunctor;
+using tensorflow::functor::TensorRangeFunctor;
+
+using OutputType = typename BlockedImageUnionFindFunctor<bool>::OutputType;
+
+// Op kernel that labels the connected components of a batch of 2D images.
+template <typename Device, typename T>
+class ImageConnectedComponents : public OpKernel {
+ public:
+  explicit ImageConnectedComponents(OpKernelConstruction* ctx)
+      : OpKernel(ctx) {}
+
+  // Validates the rank-3 input, allocates the union-find scratch tensors and
+  // the output, initializes the scratch tensors, and then delegates to the
+  // device-specific ImageConnectedComponentsFunctor.
+  void Compute(OpKernelContext* ctx) override {
+    const Tensor& images_t = ctx->input(0);
+    const auto& shape = images_t.shape();
+    OP_REQUIRES(ctx, shape.dims() == 3,
+                errors::InvalidArgument("Input images must have rank 3"));
+
+    // Scratch tensors hold one forest entry and one rank entry per pixel.
+    Tensor forest_t, rank_t;
+    OP_REQUIRES_OK(ctx,
+                   ctx->allocate_temp(tensorflow::DT_INT64, shape, &forest_t));
+    OP_REQUIRES_OK(ctx,
+                   ctx->allocate_temp(tensorflow::DT_INT64, shape, &rank_t));
+    Tensor* output_t;
+    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, shape, &output_t));
+
+    const Device& device = ctx->eigen_device<Device>();
+
+    // Every node starts out as the root of its own tree: the forest is filled
+    // with the values 0 to n - 1, so that each node points to itself.
+    TensorRangeFunctor<Device>()(device, forest_t.flat<OutputType>());
+    // All ranks start at zero.
+    auto rank = rank_t.tensor<OutputType, 3>();
+    rank.device(device) = rank.constant(OutputType(0));
+
+    auto forest = forest_t.tensor<OutputType, 3>();
+    const auto images = images_t.tensor<T, 3>();
+    ImageConnectedComponentsFunctor<Device, T>()(
+        ctx, output_t->flat<OutputType>(), images, forest, rank);
+  }
+};
+
+using CPUDevice = Eigen::ThreadPoolDevice;
+
+namespace functor {
+
+// Connected components CPU implementation. See `segmentation_ops.h` for a
+// description of the algorithm.
+template <typename T>
+struct ImageConnectedComponentsFunctor<CPUDevice, T> {
+  // Labels every pixel of `images` and writes the labels to `output`.
+  //
+  // `forest` and `rank` are union-find scratch buffers with one entry per
+  // pixel. `forest` must already hold each entry's own flat index and `rank`
+  // must be all zeros. On return `output` holds 0 for zero pixels and
+  // (root index + 1) for every other pixel.
+  void operator()(OpKernelContext* ctx,
+                  typename TTypes<OutputType>::Flat output,
+                  typename TTypes<T, 3>::ConstTensor images,
+                  typename TTypes<OutputType, 3>::Tensor forest,
+                  typename TTypes<OutputType, 3>::Tensor rank) {
+    const int64 num_images = images.dimension(0),
+                num_rows = images.dimension(1), num_cols = images.dimension(2),
+                num_elements = images.size();
+    // Bail out early for an empty image--no work to do.
+    if (num_elements == 0) {
+      return;
+    }
+    auto worker_threads = ctx->device()->tensorflow_cpu_worker_threads();
+    BlockedImageUnionFindFunctor<T> union_find(
+        images.data(), num_rows, num_cols, forest.data(), rank.data());
+    while (union_find.can_merge()) {
+      union_find.merge_blocks();
+      const int64 num_blocks_vertically = union_find.num_blocks_vertically();
+      const int64 num_blocks_horizontally =
+          union_find.num_blocks_horizontally();
+      // Merging each block calls union_down for each pixel in a row of the
+      // block, and union_right for each pixel in a column of the block. Assume
+      // 20 instructions for each call to union_down or union_right. find() may
+      // loop more while searching for the root, but this should not be very
+      // significant.
+      // Kept as int64 so the block dimensions are not narrowed to int before
+      // being handed to Shard.
+      const int64 cost =
+          (union_find.block_height() + union_find.block_width()) * 20;
+      // Shard receives one work item per (image, block_y, block_x) triple,
+      // flattened into a single index.
+      Shard(worker_threads->num_threads, worker_threads->workers,
+            num_images * num_blocks_vertically * num_blocks_horizontally, cost,
+            [&union_find, num_blocks_vertically, num_blocks_horizontally](
+                int64 start_block, int64 limit_block) {
+              for (int64 i = start_block; i < limit_block; i++) {
+                // Decode the flat work index back into block coordinates.
+                int64 block_x = i % num_blocks_horizontally;
+                int64 block_y =
+                    (i / num_blocks_horizontally) % num_blocks_vertically;
+                int64 image =
+                    i / (num_blocks_horizontally * num_blocks_vertically);
+                union_find.merge_internal_block_edges(image, block_y, block_x);
+              }
+            });
+    }
+    // Resolve every pixel to its root to produce the (arbitrary) component ids.
+    FindRootFunctor<CPUDevice, T>()(ctx->eigen_device<CPUDevice>(), output,
+                                    images.data(), union_find);
+  }
+};
+
+} // end namespace functor
+
+// Registers the CPU kernel of the "ImageConnectedComponents" op for a single
+// input element type, constrained through the op's "dtype" attribute.
+#define REGISTER_IMAGE_CONNECTED_COMPONENTS(TYPE) \
+ REGISTER_KERNEL_BUILDER(Name("ImageConnectedComponents") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<TYPE>("dtype"), \
+ ImageConnectedComponents<CPUDevice, TYPE>)
+// Connected components (arguably) make sense for number, bool, and string
+// types.
+TF_CALL_NUMBER_TYPES(REGISTER_IMAGE_CONNECTED_COMPONENTS);
+TF_CALL_bool(REGISTER_IMAGE_CONNECTED_COMPONENTS);
+TF_CALL_string(REGISTER_IMAGE_CONNECTED_COMPONENTS);
+#undef REGISTER_IMAGE_CONNECTED_COMPONENTS
+
+// TODO(ringwalt): Implement on GPU. We probably want to stick to the original
+// algorithm by Stava and Benes there for efficiency (computing small blocks in
+// shared memory in CUDA thread blocks, instead of starting with single-pixel
+// blocks).
+
+} // end namespace tensorflow
diff --git a/tensorflow/contrib/image/kernels/segmentation_ops.h b/tensorflow/contrib/image/kernels/segmentation_ops.h
new file mode 100644
index 0000000000..0957d5fd10
--- /dev/null
+++ b/tensorflow/contrib/image/kernels/segmentation_ops.h
@@ -0,0 +1,303 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_IMAGE_KERNELS_SEGMENTATION_OPS_H_
+#define TENSORFLOW_CONTRIB_IMAGE_KERNELS_SEGMENTATION_OPS_H_
+
+// Connected component analysis. The op is described in ../ops/image_ops.cc. A
+// description of the algorithm appears below.
+
+#define EIGEN_USE_THREADS
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/work_sharder.h"
+
+namespace tensorflow {
+
+namespace functor {
+
+// Returns whether `value` is a foreground pixel, i.e. differs from the zero
+// value of its type.
+template <typename T>
+bool is_nonzero(T value) {
+  return value != T(0);
+}
+
+// String pixels are foreground when they are non-empty.
+// A full specialization of a function template is an ordinary (non-template)
+// function, so it must be `inline` to be defined in a header without causing
+// multiple-definition errors when several translation units include it.
+template <>
+inline bool is_nonzero(string value) {
+  return !value.empty();
+}
+
+// Processes each pixel of an image for union-find, in parallel blocks. This is
+// loosely based on the algorithm in "GPU Computing Gems" by Ondrej Stava and
+// Bedrich Benes, available here:
+// http://hpcg.purdue.edu/bbenes/papers/Stava2011CCL.pdf
+// The bulk of the process uses blocks of each image, which have each been
+// processed separately. As long as there are multiple blocks in the image, we
+// double the height and width of the blocks, creating new blocks which each
+// consist of 2x2 previous sub-blocks. On each new block, we process adjacent
+// pixels from the previous sub-blocks serially. However, the new blocks are not
+// connected, so we can process each block in parallel.
+// The GPU algorithm first processes blocks of a fixed size in GPU shared
+// memory, with one image block per CUDA thread block. On the CPU, we just start
+// with a block size of a single pixel, and borrow the rest of the algorithm
+// unchanged.
+template <typename T>
+class BlockedImageUnionFindFunctor {
+ public:
+  using OutputType = int64;
+
+  // Wraps caller-owned buffers; nothing is copied or allocated here.
+  //   images: pixel values, indexed as col + num_cols * (row + num_rows * n).
+  //   num_rows, num_cols: height and width of every image.
+  //   forest: union-find parent links, one per pixel. find() expects roots to
+  //     hold their own index.
+  //   rank: union-find rank, one per pixel.
+  // Blocks start out as single pixels (1x1).
+  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE BlockedImageUnionFindFunctor(
+      const T* images, const int64 num_rows, const int64 num_cols,
+      OutputType* forest, OutputType* rank)
+      : images_(images),
+        num_rows_(num_rows),
+        num_cols_(num_cols),
+        block_height_(1),
+        block_width_(1),
+        forest_(forest),
+        rank_(rank) {}
+
+  // Returns the root of the tree that the pixel at the given index belongs to.
+  // No path compression is performed, so the forest is left untouched.
+  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE OutputType
+  find(OutputType index) const {
+    while (forest_[index] != index) {
+      index = forest_[index];
+    }
+    return index;
+  }
+
+  // Returns the number of blocks along the y axis (rounded up).
+  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE int64 num_blocks_vertically() const {
+    return (num_rows_ + block_height_ - 1) / block_height_;
+  }
+
+  // Returns the number of blocks along the x axis (rounded up).
+  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE int64 num_blocks_horizontally() const {
+    return (num_cols_ + block_width_ - 1) / block_width_;
+  }
+
+  // Returns the total number of blocks in each image.
+  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE int64 num_blocks() const {
+    return num_blocks_vertically() * num_blocks_horizontally();
+  }
+
+  // Returns the current block height in pixels.
+  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE int64 block_height() const {
+    return block_height_;
+  }
+
+  // Returns the current block width in pixels.
+  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE int64 block_width() const {
+    return block_width_;
+  }
+
+  // Returns whether we may merge again (the image contains more than one
+  // block).
+  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool can_merge() const {
+    return block_height_ < num_rows_ || block_width_ < num_cols_;
+  }
+
+  // Doubles the block size. After this method, you must call
+  // `merge_internal_block_edges` for each image and each *new* block's xy
+  // coordinates (typically in parallel).
+  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void merge_blocks() {
+    block_height_ *= 2;
+    block_width_ *= 2;
+  }
+
+  // Processes pairs of pixels within the block which were adjacent in the four
+  // sub-blocks. This must be done at each stage so that the connected
+  // components in each block are joined correctly.
+  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void merge_internal_block_edges(
+      int64 image_index, int64 block_vertical_index,
+      int64 block_horizontal_index) const {
+    int64 block_start_y = block_vertical_index * block_height_;
+    int64 block_start_x = block_horizontal_index * block_width_;
+    // Merge the 4 sub-blocks horizontally (fixing the vertical seam).
+    // block_center_x is the last column of the left-hand sub-blocks.
+    int64 block_center_x = block_start_x + block_width_ / 2 - 1;
+    if (0 <= block_center_x && block_center_x + 1 < num_cols_) {
+      // Clamp to the image: blocks on the bottom edge may be partial.
+      int64 merge_blocks_limit_y =
+          std::min(num_rows_, block_start_y + block_height_);
+      for (int64 y = block_start_y; y < merge_blocks_limit_y; y++) {
+        union_right(image_index, y, block_center_x);
+      }
+    }
+    // Merge the 4 sub-blocks vertically (fixing the horizontal seam).
+    // block_center_y is the last row of the upper sub-blocks.
+    int64 block_center_y = block_start_y + block_height_ / 2 - 1;
+    if (0 <= block_center_y && block_center_y + 1 < num_rows_) {
+      // Clamp to the image: blocks on the right edge may be partial.
+      int64 merge_blocks_limit_x =
+          std::min(num_cols_, block_start_x + block_width_);
+      for (int64 x = block_start_x; x < merge_blocks_limit_x; x++) {
+        union_down(image_index, block_center_y, x);
+      }
+    }
+  }
+
+ private:
+  // The input image(s).
+  const T* const images_;
+  const int64 num_rows_;
+  const int64 num_cols_;
+  // Current height of each sub-block of the image.
+  int64 block_height_;
+  // Current width of each sub-block of the image.
+  int64 block_width_;
+  // Union-find forest. This has the same size as `images_`, and each entry
+  // holds the index of its parent in `images_` (roots hold their own index).
+  // Cycles should not occur.
+  OutputType* const forest_;
+  // Union-find rank of each pixel.
+  OutputType* const rank_;
+
+  // Unions the pixel with the pixel below it if applicable (the pixel is
+  // nonzero, is not in the last row, and the pixel below has the same value).
+  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void union_down(OutputType batch,
+                                                        OutputType row,
+                                                        OutputType col) const {
+    T pixel = read_pixel(batch, row, col);
+    if (is_nonzero<T>(pixel)) {
+      const int64 index_a = col + num_cols_ * (row + num_rows_ * batch);
+      if (row + 1 < num_rows_ && read_pixel(batch, row + 1, col) == pixel) {
+        const int64 index_b = col + num_cols_ * (row + 1 + num_rows_ * batch);
+        do_union(index_a, index_b);
+      }
+    }
+  }
+
+  // Unions the pixel with the pixel to the right of it if applicable (the
+  // pixel is nonzero, is not in the last column, and the pixel to the right
+  // has the same value).
+  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void union_right(OutputType batch,
+                                                         OutputType row,
+                                                         OutputType col) const {
+    T pixel = read_pixel(batch, row, col);
+    if (is_nonzero<T>(pixel)) {
+      const int64 index_a = col + num_cols_ * (row + num_rows_ * batch);
+      if (col + 1 < num_cols_ && read_pixel(batch, row, col + 1) == pixel) {
+        const int64 index_b = col + 1 + num_cols_ * (row + num_rows_ * batch);
+        do_union(index_a, index_b);
+      }
+    }
+  }
+
+  // Reads a pixel value in the images.
+  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T
+  read_pixel(const OutputType batch, const OutputType row,
+             const OutputType col) const {
+    return images_[col + num_cols_ * (row + num_rows_ * batch)];
+  }
+
+  // Unions the trees that the two pixels belong to, using their index in the
+  // `images_` array.
+  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void do_union(
+      OutputType index_a, OutputType index_b) const {
+    // Find the roots of index_a and index_b in the forest, and make one the
+    // child of the other.
+    index_a = find(index_a);
+    index_b = find(index_b);
+    if (index_a == index_b) {
+      // Already in the same tree.
+      return;
+    }
+    const OutputType rank_a = rank_[index_a];
+    const OutputType rank_b = rank_[index_b];
+    // Union by rank: the root with the smaller rank is attached beneath the
+    // root with the larger rank, so the resulting tree does not get taller.
+    // Only when both ranks are equal does the surviving root's rank grow.
+    OutputType parent, child;
+    if (rank_a > rank_b) {
+      parent = index_a;
+      child = index_b;
+    } else {
+      parent = index_b;
+      child = index_a;
+      if (rank_a == rank_b) {
+        rank_[parent]++;
+      }
+    }
+    forest_[child] = parent;
+  }
+};
+
+// Runs the BlockedImageUnionFindFunctor on all pixels and writes one component
+// id per pixel to `output`. Will require different CPU and GPU
+// implementations, so only the call signature is declared here; each device
+// provides its own specialization.
+template <typename Device, typename T>
+class ImageConnectedComponentsFunctor {
+ public:
+  using OutputType = typename BlockedImageUnionFindFunctor<T>::OutputType;
+
+  // ctx: kernel context of the running op.
+  // output: flat tensor receiving the component id of every pixel.
+  // images: input images with shape (N, H, W).
+  // forest, rank: union-find scratch tensors with the same shape as `images`.
+  void operator()(OpKernelContext* ctx,
+                  typename TTypes<OutputType>::Flat output,
+                  typename TTypes<T, 3>::ConstTensor images,
+                  typename TTypes<OutputType, 3>::Tensor forest,
+                  typename TTypes<OutputType, 3>::Tensor rank);
+};
+
+// Writes the sequence 0, 1, ..., n - 1 into a flat Tensor of n elements.
+template <typename Device>
+class TensorRangeFunctor {
+ public:
+  using OutputType = typename BlockedImageUnionFindFunctor<bool>::OutputType;
+
+ private:
+  // Eigen generator whose value at each position is the position itself.
+  class IndexGenerator {
+   public:
+    EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE OutputType
+    operator()(const Eigen::array<Eigen::DenseIndex, 1>& coords) const {
+      return coords[0];
+    }
+  };
+
+ public:
+  // Evaluates the generator over `tensor` on the given device.
+  void operator()(const Device& device,
+                  typename TTypes<OutputType>::Flat tensor) {
+    tensor.device(device) = tensor.generate(IndexGenerator());
+  }
+};
+
+// Converts a finished union-find forest into per-pixel component ids by
+// looking up the root of every node. The resulting ids are arbitrary and
+// usually non-consecutive; they are massaged in Python to get deterministic,
+// consecutive ids.
+template <typename Device, typename T>
+class FindRootFunctor {
+ public:
+  using OutputType = typename BlockedImageUnionFindFunctor<T>::OutputType;
+
+ private:
+  // Eigen generator producing the component id for a flat pixel index.
+  class FindRootGenerator {
+   public:
+    EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE FindRootGenerator(
+        const T* images, BlockedImageUnionFindFunctor<T> union_find)
+        : images_(images), union_find_(union_find) {}
+
+    EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE OutputType
+    operator()(const Eigen::array<Eigen::DenseIndex, 1>& coords) const {
+      const Eigen::DenseIndex pixel_index = coords[0];
+      if (!is_nonzero<T>(images_[pixel_index])) {
+        // False pixels have a segment of 0.
+        return 0;
+      }
+      // True pixels have an arbitrary segment id > 0 (root index plus one).
+      // The segment ids will be made contiguous later.
+      return union_find_.find(pixel_index) + 1;
+    }
+
+   private:
+    const T* const images_;
+    const BlockedImageUnionFindFunctor<T> union_find_;
+  };
+
+ public:
+  // Fills `component_ids` with the id of the component each pixel belongs to.
+  void operator()(const Device& device,
+                  typename TTypes<OutputType>::Flat component_ids,
+                  const T* images,
+                  const BlockedImageUnionFindFunctor<T>& union_find) {
+    FindRootGenerator generator(images, union_find);
+    component_ids.device(device) = component_ids.generate(generator);
+  }
+};
+
+} // end namespace functor
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CONTRIB_IMAGE_KERNELS_SEGMENTATION_OPS_H_
diff --git a/tensorflow/contrib/image/ops/image_ops.cc b/tensorflow/contrib/image/ops/image_ops.cc
index 4527fdd87a..68771b3d05 100644
--- a/tensorflow/contrib/image/ops/image_ops.cc
+++ b/tensorflow/contrib/image/ops/image_ops.cc
@@ -98,4 +98,34 @@ col_to_row_match_indices: A vector of length num_columns, which is the number
`col_to_row_match_indices[j]`.
)doc");
+// Low-level connected components op. The ids it produces are arbitrary; the
+// Python wrapper `connected_components` renumbers them to consecutive ids.
+REGISTER_OP("ImageConnectedComponents")
+    .Input("image: dtype")
+    .Output("components: int64")
+    .Attr(
+        "dtype: {int64, int32, uint16, int16, uint8, int8, half, float, "
+        "double, bool, string}")
+    .SetShapeFn([](InferenceContext* c) {
+      return shape_inference::UnchangedShape(c);
+    })
+    .Doc(R"doc(
+Find the connected components of image(s).
+
+For each image (along the 0th axis), all connected components of adjacent pixels
+with the same non-zero value are detected and given unique ids.
+
+The returned `components` tensor has 0s for the zero pixels of `image`, and
+arbitrary nonzero ids for the connected components of nonzero values. Ids are
+unique across all of the images, but are not guaranteed to be consecutive or
+to follow any particular order.
+
+Uses union-find with union by rank but not path compression, giving a runtime of
+`O(n log n)`. See:
+  https://en.wikipedia.org/wiki/Disjoint-set_data_structure#Time_Complexity
+
+image: Image(s) with shape (N, H, W).
+components: Component ids for each pixel in "image". Same shape as "image". Zero
+  pixels all have an output of 0, and each component of adjacent pixels with
+  the same value is given its own positive id.
+)doc");
+
+
} // namespace tensorflow
diff --git a/tensorflow/contrib/image/python/kernel_tests/segmentation_test.py b/tensorflow/contrib/image/python/kernel_tests/segmentation_test.py
new file mode 100644
index 0000000000..48066cbace
--- /dev/null
+++ b/tensorflow/contrib/image/python/kernel_tests/segmentation_test.py
@@ -0,0 +1,189 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for connected component analysis."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import logging
+
+import numpy as np
+
+from tensorflow.contrib.image.python.ops import image_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import googletest
+
+# Image for testing connected_components, with a single, winding component.
+# The tests below treat (row 1, col 1) and (row 6, col 3) as the two endpoints
+# of the snake.
+SNAKE = np.asarray(
+ [[0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 1, 1, 1, 1, 0, 0, 0, 0],
+ [0, 0, 0, 0, 1, 1, 1, 1, 0],
+ [0, 0, 0, 0, 0, 0, 0, 1, 0],
+ [0, 1, 1, 1, 1, 1, 1, 1, 0],
+ [0, 1, 0, 0, 0, 0, 0, 0, 0],
+ [0, 1, 0, 1, 1, 1, 1, 1, 0],
+ [0, 1, 0, 0, 0, 0, 0, 1, 0],
+ [0, 1, 1, 1, 1, 1, 1, 1, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0]]) # pyformat: disable
+
+
+class SegmentationTest(test_util.TensorFlowTestCase):
+
+ def testDisconnected(self):
+ arr = math_ops.cast(
+ [[1, 0, 0, 1, 0, 0, 0, 0, 1],
+ [0, 1, 0, 0, 0, 1, 0, 1, 0],
+ [1, 0, 1, 0, 0, 0, 1, 0, 0],
+ [0, 0, 0, 0, 1, 0, 0, 0, 0],
+ [0, 0, 1, 0, 0, 0, 0, 0, 0]],
+ dtypes.bool) # pyformat: disable
+ expected = (
+ [[1, 0, 0, 2, 0, 0, 0, 0, 3],
+ [0, 4, 0, 0, 0, 5, 0, 6, 0],
+ [7, 0, 8, 0, 0, 0, 9, 0, 0],
+ [0, 0, 0, 0, 10, 0, 0, 0, 0],
+ [0, 0, 11, 0, 0, 0, 0, 0, 0]]) # pyformat: disable
+ with self.test_session():
+ self.assertAllEqual(image_ops.connected_components(arr).eval(), expected)
+
+ def testSimple(self):
+ arr = [[0, 1, 0], [1, 1, 1], [0, 1, 0]]
+ with self.test_session():
+ # Single component with id 1.
+ self.assertAllEqual(
+ image_ops.connected_components(math_ops.cast(
+ arr, dtypes.bool)).eval(), arr)
+
+ def testSnake(self):
+ with self.test_session():
+ # Single component with id 1.
+ self.assertAllEqual(
+ image_ops.connected_components(math_ops.cast(
+ SNAKE, dtypes.bool)).eval(), SNAKE)
+
+ def testSnake_disconnected(self):
+ for i in range(SNAKE.shape[0]):
+ for j in range(SNAKE.shape[1]):
+ with self.test_session():
+ # If we disconnect any part of the snake except for the endpoints,
+ # there will be 2 components.
+ if SNAKE[i, j] and (i, j) not in [(1, 1), (6, 3)]:
+ disconnected_snake = SNAKE.copy()
+ disconnected_snake[i, j] = 0
+ components = image_ops.connected_components(
+ math_ops.cast(disconnected_snake, dtypes.bool)).eval()
+ self.assertEqual(components.max(), 2, 'disconnect (%d, %d)' % (i,
+ j))
+ bins = np.bincount(components.ravel())
+ # Nonzero number of pixels labeled 0, 1, or 2.
+ self.assertGreater(bins[0], 0)
+ self.assertGreater(bins[1], 0)
+ self.assertGreater(bins[2], 0)
+
+ def testMultipleImages(self):
+ images = [[[1, 1, 1, 1],
+ [1, 0, 0, 1],
+ [1, 0, 0, 1],
+ [1, 1, 1, 1]],
+ [[1, 0, 0, 1],
+ [0, 0, 0, 0],
+ [0, 0, 0, 0],
+ [1, 0, 0, 1]],
+ [[1, 1, 0, 1],
+ [0, 1, 1, 0],
+ [1, 0, 1, 0],
+ [0, 0, 1, 1]]] # pyformat: disable
+ expected = [[[1, 1, 1, 1],
+ [1, 0, 0, 1],
+ [1, 0, 0, 1],
+ [1, 1, 1, 1]],
+ [[2, 0, 0, 3],
+ [0, 0, 0, 0],
+ [0, 0, 0, 0],
+ [4, 0, 0, 5]],
+ [[6, 6, 0, 7],
+ [0, 6, 6, 0],
+ [8, 0, 6, 0],
+ [0, 0, 6, 6]]] # pyformat: disable
+ with self.test_session():
+ self.assertAllEqual(
+ image_ops.connected_components(math_ops.cast(
+ images, dtypes.bool)).eval(), expected)
+
+ def testZeros(self):
+ with self.test_session():
+ self.assertAllEqual(
+ image_ops.connected_components(
+ array_ops.zeros((100, 20, 50), dtypes.bool)).eval(),
+ np.zeros((100, 20, 50)))
+
+ def testOnes(self):
+ with self.test_session():
+ self.assertAllEqual(
+ image_ops.connected_components(
+ array_ops.ones((100, 20, 50), dtypes.bool)).eval(),
+ np.tile(np.arange(100)[:, None, None] + 1, [1, 20, 50]))
+
+ def testOnes_small(self):
+ with self.test_session():
+ self.assertAllEqual(
+ image_ops.connected_components(array_ops.ones((3, 5),
+ dtypes.bool)).eval(),
+ np.ones((3, 5)))
+
+ def testRandom_scipy(self):
+ np.random.seed(42)
+ images = np.random.randint(0, 2, size=(10, 100, 200)).astype(np.bool)
+ expected = connected_components_reference_implementation(images)
+ if expected is None:
+ return
+ with self.test_session():
+ self.assertAllEqual(
+ image_ops.connected_components(images).eval(), expected)
+
+
+def connected_components_reference_implementation(images):
+ try:
+ # pylint: disable=g-import-not-at-top
+ from scipy.ndimage import measurements
+ except ImportError:
+ logging.exception('Skipping test method because scipy could not be loaded')
+ return
+ image_or_images = np.asarray(images)
+ if len(image_or_images.shape) == 2:
+ images = image_or_images[None, :, :]
+ elif len(image_or_images.shape) == 3:
+ images = image_or_images
+ components = np.asarray([measurements.label(image)[0] for image in images])
+ # Get the count of nonzero ids for each image, and offset each image's nonzero
+ # ids using the cumulative sum.
+ num_ids_per_image = components.reshape(
+ [-1, components.shape[1] * components.shape[2]]).max(axis=-1)
+ positive_id_start_per_image = np.cumsum(num_ids_per_image)
+ for i in range(components.shape[0]):
+ new_id_start = positive_id_start_per_image[i - 1] if i > 0 else 0
+ components[i, components[i] > 0] += new_id_start
+ if len(image_or_images.shape) == 2:
+ return components[0, :, :]
+ else:
+ return components
+
+
+if __name__ == '__main__':
+ googletest.main()
diff --git a/tensorflow/contrib/image/python/ops/image_ops.py b/tensorflow/contrib/image/python/ops/image_ops.py
index faedee6f87..63377ae503 100644
--- a/tensorflow/contrib/image/python/ops/image_ops.py
+++ b/tensorflow/contrib/image/python/ops/image_ops.py
@@ -24,6 +24,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import resource_loader
@@ -34,6 +35,7 @@ _image_ops_so = loader.load_op_library(
_IMAGE_DTYPES = set(
[dtypes.uint8, dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64])
+ops.RegisterShape("ImageConnectedComponents")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("ImageProjectiveTransform")(common_shapes.call_cpp_shape_fn)
@@ -395,4 +397,72 @@ def bipartite_match(distance_mat,
return result
+def connected_components(images):
+ """Labels the connected components in a batch of images.
+
+ A component is a set of pixels in a single input image, which are all adjacent
+ and all have the same non-zero value. The components using a squared
+ connectivity of one (all True entries are joined with their neighbors above,
+ below, left, and right). Components across all images have consecutive ids 1
+ through n. Components are labeled according to the first pixel of the
+ component appearing in row-major order (lexicographic order by
+ image_index_in_batch, row, col). Zero entries all have an output id of 0.
+
+ This op is equivalent with `scipy.ndimage.measurements.label` on a 2D array
+ with the default structuring element (which is the connectivity used here).
+
+ Args:
+ images: A 2D (H, W) or 3D (N, H, W) Tensor of boolean image(s).
+
+ Returns:
+ Components with the same shape as `images`. False entries in `images` have
+ value 0, and all True entries map to a component id > 0.
+
+ Raises:
+ TypeError: if `images` is not 2D or 3D.
+ """
+ with ops.name_scope("connected_components"):
+ image_or_images = ops.convert_to_tensor(images, name="images")
+ if len(image_or_images.get_shape()) == 2:
+ images = image_or_images[None, :, :]
+ elif len(image_or_images.get_shape()) == 3:
+ images = image_or_images
+ else:
+ raise TypeError(
+ "images should have rank 2 (HW) or 3 (NHW). Static shape is %s" %
+ image_or_images.get_shape())
+ components = gen_image_ops.image_connected_components(images)
+
+ # TODO(ringwalt): Component id renaming should be done in the op, to avoid
+ # constructing multiple additional large tensors.
+ components_flat = array_ops.reshape(components, [-1])
+ unique_ids, id_index = array_ops.unique(components_flat)
+ id_is_zero = array_ops.where(math_ops.equal(unique_ids, 0))[:, 0]
+ # Map each nonzero id to consecutive values.
+ nonzero_consecutive_ids = math_ops.range(
+ array_ops.shape(unique_ids)[0] - array_ops.shape(id_is_zero)[0]) + 1
+
+ def no_zero():
+ # No need to insert a zero into the ids.
+ return nonzero_consecutive_ids
+
+ def has_zero():
+ # Insert a zero in the consecutive ids where zero appears in unique_ids.
+ # id_is_zero has length 1.
+ zero_id_ind = math_ops.to_int32(id_is_zero[0])
+ ids_before = nonzero_consecutive_ids[:zero_id_ind]
+ ids_after = nonzero_consecutive_ids[zero_id_ind:]
+ return array_ops.concat([ids_before, [0], ids_after], axis=0)
+
+ new_ids = control_flow_ops.cond(
+ math_ops.equal(array_ops.shape(id_is_zero)[0], 0), no_zero, has_zero)
+ components = array_ops.reshape(
+ array_ops.gather(new_ids, id_index), array_ops.shape(components))
+ if len(image_or_images.get_shape()) == 2:
+ return components[0, :, :]
+ else:
+ return components
+
+
ops.NotDifferentiable("BipartiteMatch")
+ops.NotDifferentiable("ImageConnectedComponents")