From a41ab15aeea526355d807fcf35e057ece0e35bc4 Mon Sep 17 00:00:00 2001 From: Dan Ringwalt Date: Wed, 17 Jan 2018 11:33:12 -0800 Subject: Add tf.contrib.image.connected_components. Comparable to scipy.ndimage.measurements.label. PiperOrigin-RevId: 182244926 --- tensorflow/contrib/cmake/tf_core_kernels.cmake | 1 + tensorflow/contrib/image/BUILD | 21 ++ tensorflow/contrib/image/__init__.py | 14 +- .../contrib/image/kernels/segmentation_ops.cc | 139 ++++++++++ .../contrib/image/kernels/segmentation_ops.h | 303 +++++++++++++++++++++ tensorflow/contrib/image/ops/image_ops.cc | 30 ++ .../image/python/kernel_tests/segmentation_test.py | 189 +++++++++++++ tensorflow/contrib/image/python/ops/image_ops.py | 70 +++++ 8 files changed, 766 insertions(+), 1 deletion(-) create mode 100644 tensorflow/contrib/image/kernels/segmentation_ops.cc create mode 100644 tensorflow/contrib/image/kernels/segmentation_ops.h create mode 100644 tensorflow/contrib/image/python/kernel_tests/segmentation_test.py (limited to 'tensorflow') diff --git a/tensorflow/contrib/cmake/tf_core_kernels.cmake b/tensorflow/contrib/cmake/tf_core_kernels.cmake index d3b6c0bdd3..90a724a573 100644 --- a/tensorflow/contrib/cmake/tf_core_kernels.cmake +++ b/tensorflow/contrib/cmake/tf_core_kernels.cmake @@ -79,6 +79,7 @@ if(tensorflow_BUILD_CONTRIB_KERNELS) "${tensorflow_source_dir}/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.cc" "${tensorflow_source_dir}/tensorflow/contrib/image/kernels/bipartite_match_op.cc" "${tensorflow_source_dir}/tensorflow/contrib/image/kernels/image_ops.cc" + "${tensorflow_source_dir}/tensorflow/contrib/image/kernels/segmentation_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/image/kernels/single_image_random_dot_stereograms_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/image/ops/distort_image_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/image/ops/image_ops.cc" diff --git a/tensorflow/contrib/image/BUILD b/tensorflow/contrib/image/BUILD index 
ce2b279e51..3ff02e085e 100755 --- a/tensorflow/contrib/image/BUILD +++ b/tensorflow/contrib/image/BUILD @@ -14,6 +14,7 @@ load( "tf_gen_op_libs", "tf_gen_op_wrapper_py", "tf_kernel_library", + "tf_py_test", ) load("//tensorflow:tensorflow.bzl", "cuda_py_test") load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") @@ -24,6 +25,8 @@ tf_custom_op_library( "kernels/bipartite_match_op.cc", "kernels/image_ops.cc", "kernels/image_ops.h", + "kernels/segmentation_ops.cc", + "kernels/segmentation_ops.h", "ops/image_ops.cc", ], gpu_srcs = [ @@ -38,6 +41,8 @@ tf_kernel_library( "kernels/bipartite_match_op.cc", "kernels/image_ops.cc", "kernels/image_ops.h", + "kernels/segmentation_ops.cc", + "kernels/segmentation_ops.h", ], gpu_srcs = [ "kernels/image_ops_gpu.cu.cc", @@ -78,6 +83,7 @@ tf_custom_op_py_library( "//tensorflow/python:array_ops", "//tensorflow/python:common_shapes", "//tensorflow/python:constant_op", + "//tensorflow/python:control_flow_ops", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:linalg_ops", "//tensorflow/python:math_ops", @@ -188,6 +194,21 @@ cuda_py_test( ], ) +tf_py_test( + name = "segmentation_test", + size = "medium", + srcs = ["python/kernel_tests/segmentation_test.py"], + additional_deps = [ + ":distort_image_py", + ":image_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + ], +) + tf_custom_op_library( name = "python/ops/_single_image_random_dot_stereograms.so", srcs = [ diff --git a/tensorflow/contrib/image/__init__.py b/tensorflow/contrib/image/__init__.py index d030dffade..cc8ed117ba 100755 --- a/tensorflow/contrib/image/__init__.py +++ b/tensorflow/contrib/image/__init__.py @@ -20,6 +20,8 @@ This module provides functions for image manipulation; currently, chrominance transformas (including changing saturation and hue) in YIQ space and 
projective transforms (including rotation) are supported. +## Image Transformation `Ops` + @@angles_to_projective_transforms @@compose_transforms @@adjust_yiq_hsv @@ -28,19 +30,29 @@ projective transforms (including rotation) are supported. @@transform @@translate @@translations_to_projective_transforms + +## Image Segmentation `Ops` + +@@connected_components + +## Matching `Ops` + @@bipartite_match + +## Random Dot Stereogram `Ops` + @@single_image_random_dot_stereograms """ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# pylint: disable=line-too-long from tensorflow.contrib.image.python.ops.distort_image_ops import adjust_hsv_in_yiq from tensorflow.contrib.image.python.ops.distort_image_ops import random_hsv_in_yiq from tensorflow.contrib.image.python.ops.image_ops import angles_to_projective_transforms from tensorflow.contrib.image.python.ops.image_ops import compose_transforms +from tensorflow.contrib.image.python.ops.image_ops import connected_components from tensorflow.contrib.image.python.ops.image_ops import rotate from tensorflow.contrib.image.python.ops.image_ops import transform from tensorflow.contrib.image.python.ops.image_ops import translate diff --git a/tensorflow/contrib/image/kernels/segmentation_ops.cc b/tensorflow/contrib/image/kernels/segmentation_ops.cc new file mode 100644 index 0000000000..fe8bf6e21c --- /dev/null +++ b/tensorflow/contrib/image/kernels/segmentation_ops.cc @@ -0,0 +1,139 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// See docs for ImageConnectedComponents in ../ops/image_ops.cc, and description +// of the algorithm in segmentation_ops.h. + +#define EIGEN_USE_THREADS + +#include "tensorflow/contrib/image/kernels/segmentation_ops.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +using tensorflow::functor::BlockedImageUnionFindFunctor; +using tensorflow::functor::FindRootFunctor; +using tensorflow::functor::ImageConnectedComponentsFunctor; +using tensorflow::functor::TensorRangeFunctor; + +using OutputType = typename BlockedImageUnionFindFunctor::OutputType; + +// Computes connected components on batches of 2D images. +template +class ImageConnectedComponents : public OpKernel { + public: + explicit ImageConnectedComponents(OpKernelConstruction* ctx) + : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + const Tensor& images_t = ctx->input(0); + OP_REQUIRES(ctx, images_t.shape().dims() == 3, + errors::InvalidArgument("Input images must have rank 3")); + Tensor forest_t, rank_t; + OP_REQUIRES_OK(ctx, ctx->allocate_temp(tensorflow::DT_INT64, + images_t.shape(), &forest_t)); + OP_REQUIRES_OK(ctx, ctx->allocate_temp(tensorflow::DT_INT64, + images_t.shape(), &rank_t)); + Tensor* output_t; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, images_t.shape(), &output_t)); + + // Fill forest with values from 0 to n - 1, so that each node points to + // itself. 
+ TensorRangeFunctor()(ctx->eigen_device(), + forest_t.flat()); + auto rank = rank_t.tensor(); + rank.device(ctx->eigen_device()) = rank.constant(OutputType(0)); + + const auto images = images_t.tensor(); + auto forest = forest_t.tensor(); + ImageConnectedComponentsFunctor()( + ctx, output_t->flat(), images, forest, rank); + } +}; + +using CPUDevice = Eigen::ThreadPoolDevice; + +namespace functor { + +// Connected components CPU implementation. See `segmentation_ops.h` for a +// description of the algorithm. +template +struct ImageConnectedComponentsFunctor { + void operator()(OpKernelContext* ctx, + typename TTypes::Flat output, + typename TTypes::ConstTensor images, + typename TTypes::Tensor forest, + typename TTypes::Tensor rank) { + const int64 num_images = images.dimension(0), + num_rows = images.dimension(1), num_cols = images.dimension(2), + num_elements = images.size(); + // Bail out early for an empty image--no work to do. + if (num_elements == 0) { + return; + } + auto worker_threads = ctx->device()->tensorflow_cpu_worker_threads(); + BlockedImageUnionFindFunctor union_find( + images.data(), num_rows, num_cols, forest.data(), rank.data()); + while (union_find.can_merge()) { + union_find.merge_blocks(); + int64 num_blocks_vertically = union_find.num_blocks_vertically(); + int64 num_blocks_horizontally = union_find.num_blocks_horizontally(); + // Merging each block calls union_down for each pixel in a row of the + // block, and union_right for each pixel in a column of the block. Assume + // 20 instructions for each call to union_down or union_right. find() may + // loop more while searching for the root, but this should not be very + // significant. 
+ int cost = (union_find.block_height() + union_find.block_width()) * 20; + Shard(worker_threads->num_threads, worker_threads->workers, + num_images * num_blocks_vertically * num_blocks_horizontally, cost, + [&union_find, num_images, num_blocks_vertically, + num_blocks_horizontally](int64 start_block, int64 limit_block) { + for (int64 i = start_block; i < limit_block; i++) { + int64 block_x = i % num_blocks_horizontally; + int64 block_y = + (i / num_blocks_horizontally) % num_blocks_vertically; + int64 image = + i / (num_blocks_horizontally * num_blocks_vertically); + union_find.merge_internal_block_edges(image, block_y, block_x); + } + }); + } + FindRootFunctor()(ctx->eigen_device(), output, + images.data(), union_find); + } +}; + +} // end namespace functor + +#define REGISTER_IMAGE_CONNECTED_COMPONENTS(TYPE) \ + REGISTER_KERNEL_BUILDER(Name("ImageConnectedComponents") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("dtype"), \ + ImageConnectedComponents) +// Connected components (arguably) make sense for number, bool, and string types +TF_CALL_NUMBER_TYPES(REGISTER_IMAGE_CONNECTED_COMPONENTS); +TF_CALL_bool(REGISTER_IMAGE_CONNECTED_COMPONENTS); +TF_CALL_string(REGISTER_IMAGE_CONNECTED_COMPONENTS); +#undef REGISTER_IMAGE_CONNECTED_COMPONENTS + +// TODO(ringwalt): Implement on GPU. We probably want to stick to the original +// algorithm by Stava and Benes there for efficiency (computing small blocks in +// shared memory in CUDA thread blocks, instead of starting with single-pixel +// blocks). + +} // end namespace tensorflow diff --git a/tensorflow/contrib/image/kernels/segmentation_ops.h b/tensorflow/contrib/image/kernels/segmentation_ops.h new file mode 100644 index 0000000000..0957d5fd10 --- /dev/null +++ b/tensorflow/contrib/image/kernels/segmentation_ops.h @@ -0,0 +1,303 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
// Returns whether `value` differs from the zero value of its type.
template <typename T>
bool is_nonzero(T value) {
  return value != T(0);
}

// String pixels: any non-empty string is nonzero.
//
// Marked `inline` because an explicit (full) specialization of a function
// template is an ordinary function as far as the one-definition rule is
// concerned. Without it, including this header from more than one translation
// unit yields duplicate-symbol link errors.
template <>
inline bool is_nonzero<string>(string value) {
  return !value.empty();
}

// Processes each pixel of an image for union-find, in parallel blocks. This is
// loosely based on the algorithm in "GPU Computing Gems" by Ondrej Stava and
// Bedrich Benes, available here:
// http://hpcg.purdue.edu/bbenes/papers/Stava2011CCL.pdf
// The bulk of the process uses blocks of each image, which have each been
// processed separately. As long as there are multiple blocks in the image, we
// double the height and width of the blocks, creating new blocks which each
// consist of 2x2 previous sub-blocks. On each new block, we process adjacent
// pixels from the previous sub-blocks serially.
However, the new blocks are not +// connected, so we can process each block in parallel. +// The GPU algorithm first processes blocks of a fixed size in GPU shared +// memory, with one image block per CUDA thread block. On the CPU, we just start +// with a block size of a single pixel, and borrow the rest of the algorithm +// unchanged. +template +class BlockedImageUnionFindFunctor { + public: + using OutputType = int64; + + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE BlockedImageUnionFindFunctor( + const T* images, const int64 num_rows, const int64 num_cols, + OutputType* forest, OutputType* rank) + : images_(images), + num_rows_(num_rows), + num_cols_(num_cols), + block_height_(1), + block_width_(1), + forest_(forest), + rank_(rank) {} + + // Returns the root of the tree that the pixel at the given index belongs to. + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE OutputType + find(OutputType index) const { + while (forest_[index] != index) { + index = forest_[index]; + } + return index; + } + + // Returns the number of blocks along the y axis. + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE int64 num_blocks_vertically() const { + return (num_rows_ + block_height_ - 1) / block_height_; + } + + // Returns the number of blocks along the x axis. + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE int64 num_blocks_horizontally() const { + return (num_cols_ + block_width_ - 1) / block_width_; + } + + // Returns the total number of blocks in each image. + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE int64 num_blocks() const { + return num_blocks_vertically() * num_blocks_horizontally(); + } + + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE int64 block_height() const { + return block_height_; + } + + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE int64 block_width() const { + return block_width_; + } + + // Returns whether we may merge again (the image contains more than one + // block). 
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool can_merge() const { + return block_height_ < num_rows_ || block_width_ < num_cols_; + } + + // Doubles the block size. After this method, you must call + // `merge_internal_block_edges` for each image and each *new* block's xy + // coordinates (typically in parallel). + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void merge_blocks() { + block_height_ *= 2; + block_width_ *= 2; + } + + // Processes pairs of pixels within the block which were adjacent in the four + // sub-blocks. This must be done at each stage so that the connected + // components in each block are joined correctly. + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void merge_internal_block_edges( + int64 image_index, int64 block_vertical_index, + int64 block_horizontal_index) const { + int64 block_start_y = block_vertical_index * block_height_; + int64 block_start_x = block_horizontal_index * block_width_; + // Merge the 4 sub-blocks horizontally (fixing the vertical seam). + int64 block_center_x = block_start_x + block_width_ / 2 - 1; + if (0 <= block_center_x && block_center_x + 1 < num_cols_) { + int64 merge_blocks_limit_y = + std::min(num_rows_, block_start_y + block_height_); + for (int64 y = block_start_y; y < merge_blocks_limit_y; y++) { + union_right(image_index, y, block_center_x); + } + } + // Merge the 4 sub-blocks vertically (fixing the horizontal seam). + int64 block_center_y = block_start_y + block_height_ / 2 - 1; + if (0 <= block_center_y && block_center_y + 1 < num_rows_) { + int64 merge_blocks_limit_x = + std::min(num_cols_, block_start_x + block_width_); + for (int64 x = block_start_x; x < merge_blocks_limit_x; x++) { + union_down(image_index, block_center_y, x); + } + } + } + + private: + // The input image(s). + const T* const images_; + const int64 num_rows_; + const int64 num_cols_; + // Current height of each sub-block of the image. + int64 block_height_; + // Current width of each sub-block of the image. 
+ int64 block_width_; + // Union-find forest. This has the same size as `images_`, and each entry + // holds the index of its parent in `images_` (roots hold their own index). + // Cycles should not occur. + OutputType* const forest_; + // Union-find rank of each pixel. + OutputType* const rank_; + + // Unions the pixel with the pixel below it if applicable (both pixels are + // true, and the pixel is not in the last row). + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void union_down(OutputType batch, + OutputType row, + OutputType col) const { + T pixel = read_pixel(batch, row, col); + if (is_nonzero(pixel)) { + const int64 index_a = col + num_cols_ * (row + num_rows_ * batch); + if (row + 1 < num_rows_ && read_pixel(batch, row + 1, col) == pixel) { + const int64 index_b = col + num_cols_ * (row + 1 + num_rows_ * batch); + do_union(index_a, index_b); + } + } + } + + // Unions the pixel with the pixel to the right of it if applicable. + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void union_right(OutputType batch, + OutputType row, + OutputType col) const { + T pixel = read_pixel(batch, row, col); + if (is_nonzero(pixel)) { + const int64 index_a = col + num_cols_ * (row + num_rows_ * batch); + if (col + 1 < num_cols_ && read_pixel(batch, row, col + 1) == pixel) { + const int64 index_b = col + 1 + num_cols_ * (row + num_rows_ * batch); + do_union(index_a, index_b); + } + } + } + + // Reads a pixel value in the images. + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T + read_pixel(const OutputType batch, const OutputType row, + const OutputType col) const { + return images_[col + num_cols_ * (row + num_rows_ * batch)]; + } + + // Unions the trees that the two pixels belong to, using their index in the + // `images_` array. + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void do_union( + OutputType index_a, OutputType index_b) const { + // Find the roots of index_a and index_b in the forest, and make one the + // child of the other. 
+ index_a = find(index_a); + index_b = find(index_b); + const OutputType rank_a = rank_[index_a]; + const OutputType rank_b = rank_[index_b]; + OutputType parent, child; + if (index_a == index_b) { + return; + } else if (rank_a < rank_b) { + parent = index_a; + child = index_b; + } else { + parent = index_b; + child = index_a; + rank_[parent]++; + } + forest_[child] = parent; + } +}; + +// Runs the ImageUnionFindFunctor on all pixels. Will require different CPU and +// GPU implementations. +template +class ImageConnectedComponentsFunctor { + public: + using OutputType = typename BlockedImageUnionFindFunctor::OutputType; + + void operator()(OpKernelContext* ctx, + typename TTypes::ConstTensor images, + typename TTypes::Tensor forest, + typename TTypes::Tensor rank); +}; + +// Fills a flat Tensor with indices from 0 to n - 1. +template +class TensorRangeFunctor { + public: + using OutputType = typename BlockedImageUnionFindFunctor::OutputType; + + void operator()(const Device& device, + typename TTypes::Flat tensor) { + tensor.device(device) = tensor.generate(TensorRangeGenerator()); + } + + private: + class TensorRangeGenerator { + public: + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE OutputType + operator()(const Eigen::array& coords) const { + return coords[0]; + } + }; +}; + +// Given the union-find forest, generates the root index for each node. This +// gives us arbitrary, usually non-consecutive ids for each connected component. +// The ids are massaged in Python to get deterministic, consecutive ids. 
+template +class FindRootFunctor { + public: + using OutputType = typename BlockedImageUnionFindFunctor::OutputType; + + void operator()(const Device& device, + typename TTypes::Flat component_ids, + const T* images, + const BlockedImageUnionFindFunctor& union_find) { + component_ids.device(device) = + component_ids.generate(FindRootGenerator(images, union_find)); + } + + private: + class FindRootGenerator { + const T* const images_; + const BlockedImageUnionFindFunctor union_find_; + + public: + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE FindRootGenerator( + const T* images, BlockedImageUnionFindFunctor union_find) + : images_(images), union_find_(union_find) {} + + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE OutputType + operator()(const Eigen::array& coords) const { + if (is_nonzero(images_[coords[0]])) { + // True pixels have an arbitrary segment id > 0. The segment ids will be + // made contiguous later. + return union_find_.find(coords[0]) + 1; + } else { + // False pixels have a segment of 0. + return 0; + } + } + }; +}; + +} // end namespace functor + +} // namespace tensorflow + +#endif // TENSORFLOW_CONTRIB_IMAGE_KERNELS_SEGMENTATION_OPS_H_ diff --git a/tensorflow/contrib/image/ops/image_ops.cc b/tensorflow/contrib/image/ops/image_ops.cc index 4527fdd87a..68771b3d05 100644 --- a/tensorflow/contrib/image/ops/image_ops.cc +++ b/tensorflow/contrib/image/ops/image_ops.cc @@ -98,4 +98,34 @@ col_to_row_match_indices: A vector of length num_columns, which is the number `col_to_row_match_indices[j]`. )doc"); +REGISTER_OP("ImageConnectedComponents") + .Input("image: dtype") + .Output("components: int64") + .Attr( + "dtype: {int64, int32, uint16, int16, uint8, int8, half, float, " + "double, bool, string}") + .SetShapeFn([](InferenceContext* c) { + return shape_inference::UnchangedShape(c); + }) + .Doc(R"doc( +Find the connected components of image(s). 
+ +For each image (along the 0th axis), all connected components of adjacent pixels +with the same non-zero value are detected and given unique ids. + +The returned `components` tensor has 0s for the zero pixels of `image`, and +arbitrary nonzero ids for the connected components of nonzero values. Ids are +unique across all of the images, but are not guaranteed to be consecutive or +ordered. + +Uses union-find with union by rank but not path compression, giving a runtime of +`O(n log n)`. See: + https://en.wikipedia.org/wiki/Disjoint-set_data_structure#Time_Complexity + +image: Image(s) with shape (N, H, W). +components: Component ids for each pixel in "image". Same shape as "image". Zero + pixels all have an output of 0, and all pixels of one connected component + share the same arbitrary id greater than 0. +)doc"); + } // namespace tensorflow diff --git a/tensorflow/contrib/image/python/kernel_tests/segmentation_test.py b/tensorflow/contrib/image/python/kernel_tests/segmentation_test.py new file mode 100644 index 0000000000..48066cbace --- /dev/null +++ b/tensorflow/contrib/image/python/kernel_tests/segmentation_test.py @@ -0,0 +1,189 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for connected component analysis.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import logging + +import numpy as np + +from tensorflow.contrib.image.python.ops import image_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import googletest + +# Image for testing connected_components, with a single, winding component. +SNAKE = np.asarray( + [[0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 1, 1, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 1, 1, 1, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 0], + [0, 1, 1, 1, 1, 1, 1, 1, 0], + [0, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 0, 1, 1, 1, 1, 1, 0], + [0, 1, 0, 0, 0, 0, 0, 1, 0], + [0, 1, 1, 1, 1, 1, 1, 1, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0]]) # pyformat: disable + + +class SegmentationTest(test_util.TensorFlowTestCase): + + def testDisconnected(self): + arr = math_ops.cast( + [[1, 0, 0, 1, 0, 0, 0, 0, 1], + [0, 1, 0, 0, 0, 1, 0, 1, 0], + [1, 0, 1, 0, 0, 0, 1, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 0, 0, 0, 0]], + dtypes.bool) # pyformat: disable + expected = ( + [[1, 0, 0, 2, 0, 0, 0, 0, 3], + [0, 4, 0, 0, 0, 5, 0, 6, 0], + [7, 0, 8, 0, 0, 0, 9, 0, 0], + [0, 0, 0, 0, 10, 0, 0, 0, 0], + [0, 0, 11, 0, 0, 0, 0, 0, 0]]) # pyformat: disable + with self.test_session(): + self.assertAllEqual(image_ops.connected_components(arr).eval(), expected) + + def testSimple(self): + arr = [[0, 1, 0], [1, 1, 1], [0, 1, 0]] + with self.test_session(): + # Single component with id 1. + self.assertAllEqual( + image_ops.connected_components(math_ops.cast( + arr, dtypes.bool)).eval(), arr) + + def testSnake(self): + with self.test_session(): + # Single component with id 1. 
+ self.assertAllEqual( + image_ops.connected_components(math_ops.cast( + SNAKE, dtypes.bool)).eval(), SNAKE) + + def testSnake_disconnected(self): + for i in range(SNAKE.shape[0]): + for j in range(SNAKE.shape[1]): + with self.test_session(): + # If we disconnect any part of the snake except for the endpoints, + # there will be 2 components. + if SNAKE[i, j] and (i, j) not in [(1, 1), (6, 3)]: + disconnected_snake = SNAKE.copy() + disconnected_snake[i, j] = 0 + components = image_ops.connected_components( + math_ops.cast(disconnected_snake, dtypes.bool)).eval() + self.assertEqual(components.max(), 2, 'disconnect (%d, %d)' % (i, + j)) + bins = np.bincount(components.ravel()) + # Nonzero number of pixels labeled 0, 1, or 2. + self.assertGreater(bins[0], 0) + self.assertGreater(bins[1], 0) + self.assertGreater(bins[2], 0) + + def testMultipleImages(self): + images = [[[1, 1, 1, 1], + [1, 0, 0, 1], + [1, 0, 0, 1], + [1, 1, 1, 1]], + [[1, 0, 0, 1], + [0, 0, 0, 0], + [0, 0, 0, 0], + [1, 0, 0, 1]], + [[1, 1, 0, 1], + [0, 1, 1, 0], + [1, 0, 1, 0], + [0, 0, 1, 1]]] # pyformat: disable + expected = [[[1, 1, 1, 1], + [1, 0, 0, 1], + [1, 0, 0, 1], + [1, 1, 1, 1]], + [[2, 0, 0, 3], + [0, 0, 0, 0], + [0, 0, 0, 0], + [4, 0, 0, 5]], + [[6, 6, 0, 7], + [0, 6, 6, 0], + [8, 0, 6, 0], + [0, 0, 6, 6]]] # pyformat: disable + with self.test_session(): + self.assertAllEqual( + image_ops.connected_components(math_ops.cast( + images, dtypes.bool)).eval(), expected) + + def testZeros(self): + with self.test_session(): + self.assertAllEqual( + image_ops.connected_components( + array_ops.zeros((100, 20, 50), dtypes.bool)).eval(), + np.zeros((100, 20, 50))) + + def testOnes(self): + with self.test_session(): + self.assertAllEqual( + image_ops.connected_components( + array_ops.ones((100, 20, 50), dtypes.bool)).eval(), + np.tile(np.arange(100)[:, None, None] + 1, [1, 20, 50])) + + def testOnes_small(self): + with self.test_session(): + self.assertAllEqual( + 
def connected_components_reference_implementation(images):
  """Labels connected components with scipy, for comparison with the op.

  Each image is labeled independently with `scipy.ndimage.label` (default
  structuring element). The nonzero ids of every image are then shifted by the
  number of components found in the preceding images, so that ids are
  consecutive across the whole batch.

  Args:
    images: A 2D (H, W) or 3D (N, H, W) array-like of image(s). Nonzero entries
      are foreground.

  Returns:
    An integer array with the same shape as `images`, holding 0 for zero
    pixels and a component id >= 1 otherwise. Returns None if scipy cannot be
    imported, so that callers can skip the comparison.

  Raises:
    ValueError: if `images` is not 2D or 3D.
  """
  try:
    # pylint: disable=g-import-not-at-top
    # `label` is imported from `scipy.ndimage` directly; the
    # `scipy.ndimage.measurements` namespace is deprecated.
    from scipy import ndimage
  except ImportError:
    logging.exception('Skipping test method because scipy could not be loaded')
    return None
  image_or_images = np.asarray(images)
  rank = image_or_images.ndim
  if rank == 2:
    images = image_or_images[None, :, :]
  elif rank == 3:
    images = image_or_images
  else:
    raise ValueError(
        'images should have rank 2 (HW) or 3 (NHW). Shape is %s' %
        (image_or_images.shape,))
  # Pre-sizing the output keeps an empty batch (N == 0) well-defined.
  components = np.zeros(images.shape, dtype=np.int64)
  id_offset = 0
  for i, image in enumerate(images):
    # label() returns (labeled_array, num_features); labels run from 1 to
    # num_features within a single image.
    labels, num_ids = ndimage.label(image)
    # Shift only the nonzero ids, so that background stays 0.
    components[i] = np.where(labels > 0, labels + id_offset, 0)
    id_offset += num_ids
  if rank == 2:
    return components[0, :, :]
  return components
Components are labeled according to the first pixel of the + component appearing in row-major order (lexicographic order by + image_index_in_batch, row, col). Zero entries all have an output id of 0. + + This op is equivalent with `scipy.ndimage.measurements.label` on a 2D array + with the default structuring element (which is the connectivity used here). + + Args: + images: A 2D (H, W) or 3D (N, H, W) Tensor of boolean image(s). + + Returns: + Components with the same shape as `images`. False entries in `images` have + value 0, and all True entries map to a component id > 0. + + Raises: + TypeError: if `images` is not 2D or 3D. + """ + with ops.name_scope("connected_components"): + image_or_images = ops.convert_to_tensor(images, name="images") + if len(image_or_images.get_shape()) == 2: + images = image_or_images[None, :, :] + elif len(image_or_images.get_shape()) == 3: + images = image_or_images + else: + raise TypeError( + "images should have rank 2 (HW) or 3 (NHW). Static shape is %s" % + image_or_images.get_shape()) + components = gen_image_ops.image_connected_components(images) + + # TODO(ringwalt): Component id renaming should be done in the op, to avoid + # constructing multiple additional large tensors. + components_flat = array_ops.reshape(components, [-1]) + unique_ids, id_index = array_ops.unique(components_flat) + id_is_zero = array_ops.where(math_ops.equal(unique_ids, 0))[:, 0] + # Map each nonzero id to consecutive values. + nonzero_consecutive_ids = math_ops.range( + array_ops.shape(unique_ids)[0] - array_ops.shape(id_is_zero)[0]) + 1 + + def no_zero(): + # No need to insert a zero into the ids. + return nonzero_consecutive_ids + + def has_zero(): + # Insert a zero in the consecutive ids where zero appears in unique_ids. + # id_is_zero has length 1. 
+ zero_id_ind = math_ops.to_int32(id_is_zero[0]) + ids_before = nonzero_consecutive_ids[:zero_id_ind] + ids_after = nonzero_consecutive_ids[zero_id_ind:] + return array_ops.concat([ids_before, [0], ids_after], axis=0) + + new_ids = control_flow_ops.cond( + math_ops.equal(array_ops.shape(id_is_zero)[0], 0), no_zero, has_zero) + components = array_ops.reshape( + array_ops.gather(new_ids, id_index), array_ops.shape(components)) + if len(image_or_images.get_shape()) == 2: + return components[0, :, :] + else: + return components + + ops.NotDifferentiable("BipartiteMatch") +ops.NotDifferentiable("ImageConnectedComponents") -- cgit v1.2.3