diff options
author | Tayo Oguntebi <tayo@google.com> | 2018-07-30 17:13:09 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-30 17:16:55 -0700 |
commit | 8566d9e6fa7dbe3660339befe8b0a3344d24ef2b (patch) | |
tree | 54143734f9b19a9ead1a23784f75dcd2009d1452 | |
parent | 8fee2e4b7c915d952332dc8cc9be7cfefea35162 (diff) |
Adds a NonMaxSuppressionV4 op, with a corresponding TF2XLA implementation.
PiperOrigin-RevId: 206673787
-rw-r--r-- | tensorflow/compiler/tests/image_ops_test.py | 136 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/image_ops.cc | 150 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/topk_op.cc | 27 | ||||
-rw-r--r-- | tensorflow/compiler/xla/client/lib/BUILD | 31 | ||||
-rw-r--r-- | tensorflow/compiler/xla/client/lib/sorting.cc | 46 | ||||
-rw-r--r-- | tensorflow/compiler/xla/client/lib/sorting.h | 31 | ||||
-rw-r--r-- | tensorflow/compiler/xla/client/lib/sorting_test.cc | 60 | ||||
-rw-r--r-- | tensorflow/core/api_def/base_api/api_def_NonMaxSuppressionV4.pbtxt | 78 | ||||
-rw-r--r-- | tensorflow/core/api_def/python_api/api_def_NonMaxSuppressionV4.pbtxt | 4 | ||||
-rw-r--r-- | tensorflow/core/kernels/non_max_suppression_op.cc | 111 | ||||
-rw-r--r-- | tensorflow/core/kernels/non_max_suppression_op_test.cc | 55 | ||||
-rw-r--r-- | tensorflow/core/ops/image_ops.cc | 39 | ||||
-rw-r--r-- | tensorflow/python/ops/image_ops_impl.py | 59 | ||||
-rw-r--r-- | tensorflow/tools/api/golden/tensorflow.image.pbtxt | 4 |
15 files changed, 784 insertions, 48 deletions
diff --git a/tensorflow/compiler/tests/image_ops_test.py b/tensorflow/compiler/tests/image_ops_test.py index 8b01ef96db..bf986ade06 100644 --- a/tensorflow/compiler/tests/image_ops_test.py +++ b/tensorflow/compiler/tests/image_ops_test.py @@ -26,6 +26,7 @@ import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.compiler.tests import xla_test +from tensorflow.python.compat import compat from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -579,5 +580,140 @@ class ResizeBilinearTest(xla_test.XLATestCase): large_tolerance=True) +class NonMaxSuppressionTest(xla_test.XLATestCase): + + def testNMS128From1024(self): + # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU. + if self.device in ["XLA_CPU", "XLA_GPU"]: + return + + with compat.forward_compatibility_horizon(2018, 8, 8): + num_boxes = 1024 + boxes_np = np.random.normal(50, 10, (num_boxes, 4)).astype("f4") + scores_np = np.random.normal(0.5, 0.1, (num_boxes,)).astype("f4") + + max_output_size = 128 + iou_threshold_np = np.array(0.5, dtype=np.float32) + score_threshold_np = np.array(0.0, dtype=np.float32) + + with self.test_session() as sess: + boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape) + scores = array_ops.placeholder(scores_np.dtype, shape=scores_np.shape) + iou_threshold = array_ops.placeholder(iou_threshold_np.dtype, + iou_threshold_np.shape) + score_threshold = array_ops.placeholder(score_threshold_np.dtype, + score_threshold_np.shape) + with self.test_scope(): + selected_indices = image_ops.non_max_suppression_padded( + boxes=boxes, + scores=scores, + max_output_size=max_output_size, + iou_threshold=iou_threshold, + score_threshold=score_threshold, + pad_to_max_output_size=True) + inputs_feed = { + boxes: boxes_np, + scores: scores_np, + score_threshold: score_threshold_np, + iou_threshold: iou_threshold_np + } + (indices_tf, _) = sess.run(selected_indices, feed_dict=inputs_feed) + + self.assertEqual(indices_tf.size, max_output_size) + + def testNMS3From6Boxes(self): + # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU. + if self.device in ["XLA_CPU", "XLA_GPU"]: + return + + with compat.forward_compatibility_horizon(2018, 8, 8): + # Three boxes are selected based on IOU. + boxes_data = [[0, 0, 1, 1], [0, 0.1, 1, 1.1], [0, -0.1, 1, 0.9], + [0, 10, 1, 11], [0, 10.1, 1, 11.1], [0, 100, 1, 101]] + boxes_np = np.array(boxes_data, dtype=np.float32) + + scores_data = [0.9, 0.75, 0.6, 0.95, 0.5, 0.3] + scores_np = np.array(scores_data, dtype=np.float32) + + max_output_size = 3 + iou_threshold_np = np.array(0.5, dtype=np.float32) + score_threshold_np = np.array(0.0, dtype=np.float32) + + with self.test_session() as sess: + boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape) + scores = array_ops.placeholder(scores_np.dtype, shape=scores_np.shape) + iou_threshold = array_ops.placeholder(iou_threshold_np.dtype, + iou_threshold_np.shape) + score_threshold = array_ops.placeholder(score_threshold_np.dtype, + score_threshold_np.shape) + with self.test_scope(): + selected_indices = image_ops.non_max_suppression_padded( + boxes=boxes, + scores=scores, + max_output_size=max_output_size, + iou_threshold=iou_threshold, + score_threshold=score_threshold, + pad_to_max_output_size=True) + inputs_feed = { + boxes: boxes_np, + scores: scores_np, + score_threshold: score_threshold_np, + iou_threshold: iou_threshold_np + } + (indices_tf, num_valid) = sess.run( + selected_indices, feed_dict=inputs_feed) + + self.assertEqual(indices_tf.size, max_output_size) + self.assertEqual(num_valid, 3) + self.assertAllClose(indices_tf[:num_valid], [3, 0, 5]) + + def testNMS3Then2WithScoreThresh(self): + # Three boxes are selected based on IOU. + # One is filtered out by score threshold. + + # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU. + if self.device in ["XLA_CPU", "XLA_GPU"]: + return + + with compat.forward_compatibility_horizon(2018, 8, 8): + boxes_data = [[0, 0, 1, 1], [0, 0.1, 1, 1.1], [0, -0.1, 1, 0.9], + [0, 10, 1, 11], [0, 10.1, 1, 11.1], [0, 100, 1, 101]] + boxes_np = np.array(boxes_data, dtype=np.float32) + + scores_data = [0.9, 0.75, 0.6, 0.95, 0.5, 0.3] + scores_np = np.array(scores_data, dtype=np.float32) + max_output_size = 3 + iou_threshold_np = np.array(0.5, dtype=np.float32) + score_threshold_np = np.array(0.4, dtype=np.float32) + + with self.test_session() as sess: + boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape) + scores = array_ops.placeholder(scores_np.dtype, shape=scores_np.shape) + iou_threshold = array_ops.placeholder(iou_threshold_np.dtype, + iou_threshold_np.shape) + score_threshold = array_ops.placeholder(score_threshold_np.dtype, + score_threshold_np.shape) + with self.test_scope(): + selected_indices = image_ops.non_max_suppression_padded( + boxes=boxes, + scores=scores, + max_output_size=max_output_size, + iou_threshold=iou_threshold, + score_threshold=score_threshold, + pad_to_max_output_size=True) + inputs_feed = { + boxes: boxes_np, + scores: scores_np, + iou_threshold: iou_threshold_np, + score_threshold: score_threshold_np + } + (indices_tf, num_valid) = sess.run( + selected_indices, feed_dict=inputs_feed) + + self.assertEqual(indices_tf.size, max_output_size) + self.assertEqual(num_valid, 2) + self.assertAllClose(indices_tf[:num_valid], [3, 0]) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index f96483d23d..0609e22338 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -130,6 +130,7 @@ tf_kernel_library( "//tensorflow/compiler/xla/client/lib:math", "//tensorflow/compiler/xla/client/lib:numeric", "//tensorflow/compiler/xla/client/lib:prng", + "//tensorflow/compiler/xla/client/lib:sorting", "//tensorflow/core:framework", "//tensorflow/core:image_ops_op_lib", "//tensorflow/core:lib", diff --git a/tensorflow/compiler/tf2xla/kernels/image_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_ops.cc index 6e061ba278..33a73fe5fd 100644 --- a/tensorflow/compiler/tf2xla/kernels/image_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/image_ops.cc @@ -17,7 +17,12 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/sorting.h" #include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/core/framework/tensor_shape.h" namespace tensorflow { namespace { @@ -311,5 +316,150 @@ class AdjustHueOp : public XlaOpKernel { }; REGISTER_XLA_OP(Name("AdjustHue"), AdjustHueOp); +class NonMaxSuppressionOp : public XlaOpKernel { + public: + explicit NonMaxSuppressionOp(OpKernelConstruction* context) + : XlaOpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("pad_to_max_output_size", + &pad_to_max_output_size_)); + } + + void Compile(XlaOpKernelContext* context) override { + // TODO(b/111646731): Improve scalability of this op, using blocking. + int num_boxes_dim = 0; + int coords_dim = 1; + const TensorShape& boxes_shape = context->InputShape("boxes"); + OP_REQUIRES(context, TensorShapeUtils::IsMatrix(boxes_shape), + errors::InvalidArgument("boxes must be 2-D, currently: ", + boxes_shape.DebugString())); + const int64 num_boxes = boxes_shape.dim_size(num_boxes_dim); + OP_REQUIRES(context, boxes_shape.dim_size(coords_dim) == 4, + errors::InvalidArgument("boxes must have 4 columns", + boxes_shape.DebugString())); + const TensorShape& scores_shape = context->InputShape("scores"); + OP_REQUIRES(context, TensorShapeUtils::IsVector(scores_shape), + errors::InvalidArgument("scores must be 1-D, currently: ", + scores_shape.DebugString())); + OP_REQUIRES( + context, scores_shape.dim_size(0) == num_boxes, + errors::InvalidArgument("scores size must equal number of boxes", + scores_shape.DebugString())); + OP_REQUIRES(context, pad_to_max_output_size_, + errors::InvalidArgument( + "XLA compilation requires pad_to_max_output_size == True")); + + xla::XlaOp boxes = context->Input("boxes"); + xla::XlaOp scores = context->Input("scores"); + int64 output_size; + OP_REQUIRES_OK(context, context->ConstantInputAsIntScalar(2, &output_size)); + OP_REQUIRES( + context, output_size >= 0, + errors::InvalidArgument("Need output_size >= 0, got ", output_size)); + xla::XlaOp score_thresh = context->Input("score_threshold"); + xla::XlaOp iou_thresh = context->Input("iou_threshold"); + + xla::XlaBuilder* const builder = context->builder(); + + // Choose a more convenient layout. + xla::XlaOp boxes_t = xla::Transpose(boxes, {1, 0}); + coords_dim = 0; + num_boxes_dim = 1; + + // Shapes are henceforth [1, num_boxes]. + xla::XlaOp coord_y0 = xla::SliceInDim(boxes_t, + /*start_index=*/0, + /*limit_index=*/1, + /*stride=*/1, + /*dimno=*/coords_dim); + xla::XlaOp coord_x0 = xla::SliceInDim(boxes_t, + /*start_index=*/1, + /*limit_index=*/2, + /*stride=*/1, + /*dimno=*/coords_dim); + xla::XlaOp coord_y1 = xla::SliceInDim(boxes_t, + /*start_index=*/2, + /*limit_index=*/3, + /*stride=*/1, + /*dimno=*/coords_dim); + xla::XlaOp coord_x1 = xla::SliceInDim(boxes_t, + /*start_index=*/3, + /*limit_index=*/4, + /*stride=*/1, + /*dimno=*/coords_dim); + xla::XlaOp y1 = + xla::Select(xla::Le(coord_y0, coord_y1), coord_y0, coord_y1); + xla::XlaOp y2 = + xla::Select(xla::Le(coord_y0, coord_y1), coord_y1, coord_y0); + xla::XlaOp x1 = + xla::Select(xla::Le(coord_x0, coord_x1), coord_x0, coord_x1); + xla::XlaOp x2 = + xla::Select(xla::Le(coord_x0, coord_x1), coord_x1, coord_x0); + xla::XlaOp area = (y2 - y1) * (x2 - x1); + + // Transpose the 1xN tensors, instead of the NxN tensors. + xla::XlaOp y1_t = xla::Transpose(y1, {1, 0}); + xla::XlaOp y2_t = xla::Transpose(y2, {1, 0}); + xla::XlaOp x1_t = xla::Transpose(x1, {1, 0}); + xla::XlaOp x2_t = xla::Transpose(x2, {1, 0}); + xla::XlaOp area_t = xla::Transpose(area, {1, 0}); + + // Shapes are henceforth [num_boxes, num_boxes]. + xla::XlaOp i_xmin = xla::Max(x1, x1_t); + xla::XlaOp i_ymin = xla::Max(y1, y1_t); + xla::XlaOp i_xmax = xla::Min(x2, x2_t); + xla::XlaOp i_ymax = xla::Min(y2, y2_t); + auto square_zero = xla::ZerosLike(i_xmin); + + xla::XlaOp i_area = xla::Max(i_xmax - i_xmin, square_zero) * + xla::Max(i_ymax - i_ymin, square_zero); + xla::XlaOp u_area = area + area_t - i_area; + xla::XlaOp iou = i_area / u_area; + + xla::XlaOp iou_thresh_mask = xla::Gt(iou, iou_thresh + square_zero); + xla::XlaOp scores_2d = xla::Reshape(scores, {num_boxes, 1}); + xla::XlaOp score_cmp_mask = + xla::Gt(scores_2d, xla::Transpose(scores_2d, {1, 0})); + xla::XlaOp suppress = xla::And(iou_thresh_mask, score_cmp_mask); + + // Shapes are [num_boxes] after the reduce. + xla::XlaOp included_iou = xla::Not(xla::Reduce( + suppress, + /*init_value=*/xla::ConstantR0<bool>(builder, false), + /*computation=*/CreateScalarOrComputation(xla::PRED, builder), + /*dimensions_to_reduce=*/{0})); + xla::XlaOp included_score = + xla::Gt(scores, xla::Broadcast(score_thresh, {num_boxes})); + xla::XlaOp included = xla::And(included_iou, included_score); + xla::XlaOp neg_inf = + xla::Broadcast(xla::MinValue(builder, xla::F32), {num_boxes}); + xla::XlaOp scores_included = xla::Select(included, scores, neg_inf); + + xla::XlaOp ones_included = xla::Select( + included, + xla::Broadcast(xla::ConstantR0<int32>(builder, 1), {num_boxes}), + xla::Broadcast(xla::ConstantR0<int32>(builder, 0), {num_boxes})); + + // num_valid is scalar. + xla::XlaOp num_valid = xla::Reduce( + ones_included, + /*init_value=*/xla::ConstantR0<int>(builder, 0), + /*computation=*/CreateScalarAddComputation(xla::S32, builder), + /*dimensions_to_reduce=*/{0}); + + xla::XlaOp output_tuple = TopK(scores_included, output_size); + xla::XlaOp selected_indices = xla::GetTupleElement(output_tuple, 1); + + context->SetOutput(0, selected_indices); + context->SetOutput(1, num_valid); + } + + private: + bool pad_to_max_output_size_; +}; + +REGISTER_XLA_OP( + Name("NonMaxSuppressionV4").CompileTimeConstInput("max_output_size"), + NonMaxSuppressionOp); + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/topk_op.cc b/tensorflow/compiler/tf2xla/kernels/topk_op.cc index e73fb283b0..183879c760 100644 --- a/tensorflow/compiler/tf2xla/kernels/topk_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/topk_op.cc @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/lib/sorting.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/kernel_def_builder.h" @@ -47,31 +47,12 @@ class TopKOp : public XlaOpKernel { context, last_dim_size >= k, errors::InvalidArgument("input must have at least k columns. Had ", last_dim_size, ", needed ", k)); - - xla::XlaBuilder* const b = context->builder(); if (last_dim_size < k) { k = last_dim_size; } - const xla::XlaOp input = context->Input(0); - - xla::XlaOp iota_s32 = xla::Iota(b, xla::S32, last_dim_size); - auto input_dims = input_shape.dim_sizes(); - std::vector<int64> broadcast_dims(input_dims.begin(), input_dims.end() - 1); - xla::XlaOp broadcast_s32 = xla::Broadcast(iota_s32, broadcast_dims); - xla::XlaOp sort_result = xla::Sort(xla::Neg(input), broadcast_s32); - - std::vector<int64> start_indices(input_shape.dims(), 0); - std::vector<int64> limit_indices(input_dims.begin(), input_dims.end()); - limit_indices[last_dim] = k; - std::vector<int64> strides(input_shape.dims(), 1); - - xla::XlaOp values = - xla::Neg(xla::Slice(xla::GetTupleElement(sort_result, 0), start_indices, - limit_indices, strides)); - xla::XlaOp indices = xla::Slice(xla::GetTupleElement(sort_result, 1), - start_indices, limit_indices, strides); - context->SetOutput(0, values); - context->SetOutput(1, indices); + xla::XlaOp output_tuple = TopK(context->Input(0), k); + context->SetOutput(0, xla::GetTupleElement(output_tuple, 0)); + context->SetOutput(1, xla::GetTupleElement(output_tuple, 1)); } private: diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD index 789daf4728..39d5582d19 100644 --- a/tensorflow/compiler/xla/client/lib/BUILD +++ b/tensorflow/compiler/xla/client/lib/BUILD @@ -137,6 +137,37 @@ cc_library( ) cc_library( + name = "sorting", + srcs = ["sorting.cc"], + hdrs = ["sorting.h"], + deps = [ + ":numeric", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + ], +) + +xla_test( + name = "sorting_test", + srcs = ["sorting_test.cc"], + blacklisted_backends = [ + "cpu", + "gpu", + ], + tags = ["enable_for_xla_interpreter"], + deps = [ + ":sorting", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + ], +) + +cc_library( name = "testing", srcs = ["testing.cc"], hdrs = ["testing.h"], diff --git a/tensorflow/compiler/xla/client/lib/sorting.cc b/tensorflow/compiler/xla/client/lib/sorting.cc new file mode 100644 index 0000000000..a904be259a --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/sorting.cc @@ -0,0 +1,46 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/sorting.h" +#include "tensorflow/compiler/xla/client/lib/numeric.h" + +namespace xla { + +XlaOp TopK(XlaOp input, int64 k) { + XlaBuilder* const builder = input.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { + TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input)); + int last_dim = input_shape.dimensions_size() - 1; + int last_dim_size = input_shape.dimensions(last_dim); + + XlaOp iota_s32 = Iota(builder, S32, last_dim_size); + auto input_dims = input_shape.dimensions(); + std::vector<int64> broadcast_dims(input_dims.begin(), input_dims.end() - 1); + XlaOp broadcast_s32 = Broadcast(iota_s32, broadcast_dims); + XlaOp sort_result = Sort(Neg(input), broadcast_s32); + std::vector<int64> start_indices(input_shape.dimensions_size(), 0); + std::vector<int64> limit_indices(input_dims.begin(), input_dims.end()); + limit_indices[last_dim] = k; + std::vector<int64> strides(input_shape.dimensions_size(), 1); + + XlaOp values = Neg(Slice(GetTupleElement(sort_result, 0), start_indices, + limit_indices, strides)); + XlaOp indices = Slice(GetTupleElement(sort_result, 1), start_indices, + limit_indices, strides); + return Tuple(builder, {values, indices}); + }); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/sorting.h b/tensorflow/compiler/xla/client/lib/sorting.h new file mode 100644 index 0000000000..404b4783c3 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/sorting.h @@ -0,0 +1,31 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SORTING_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SORTING_H_ + +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { + +// Returns a tuple composed of the top `k` values and corresponding indices in +// `input`. Output values are in descending order, from largest to smallest. +XlaOp TopK(XlaOp input, int64 k); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SORTING_H_ diff --git a/tensorflow/compiler/xla/client/lib/sorting_test.cc b/tensorflow/compiler/xla/client/lib/sorting_test.cc new file mode 100644 index 0000000000..b6eee762a5 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/sorting_test.cc @@ -0,0 +1,60 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/sorting.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/types.h" + +namespace xla { +namespace { + +using SortingTest = ClientLibraryTestBase; + +XLA_TEST_F(SortingTest, TopK3From8Values) { + XlaBuilder builder(TestName()); + auto x = + ConstantR1<float>(&builder, {0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0}); + xla::GetTupleElement(xla::TopK(x, 3), 0); + ComputeAndCompareR1<float>(&builder, {7.0, 6.0, 5.0}, {}); +} + +XLA_TEST_F(SortingTest, TopK3From8Indices) { + XlaBuilder builder(TestName()); + auto x_rev = + ConstantR1<float>(&builder, {7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0, 0.0}); + xla::GetTupleElement(xla::TopK(x_rev, 3), 1); + ComputeAndCompareR1<int>(&builder, {0, 1, 2}, {}); +} + +XLA_TEST_F(SortingTest, TopKFullSort) { + XlaBuilder builder(TestName()); + const int kSize = 16; + std::mt19937 eng; + std::uniform_real_distribution<float> u_dist(0.0, 100.0); + auto gen = std::bind(u_dist, eng); + std::vector<float> inputs(kSize); + std::generate(inputs.begin(), inputs.end(), gen); + auto x = ConstantR1<float>(&builder, inputs); + xla::GetTupleElement(xla::TopK(x, kSize), 0); + + std::sort(inputs.begin(), inputs.end(), std::greater<float>()); + ComputeAndCompareR1<float>(&builder, inputs, {}); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/core/api_def/base_api/api_def_NonMaxSuppressionV4.pbtxt b/tensorflow/core/api_def/base_api/api_def_NonMaxSuppressionV4.pbtxt new file mode 100644 index 0000000000..75df90f570 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_NonMaxSuppressionV4.pbtxt @@ -0,0 +1,78 @@ +op { + graph_op_name: "NonMaxSuppressionV4" + in_arg { + name: "boxes" + description: <<END +A 2-D float tensor of shape `[num_boxes, 4]`. +END + } + in_arg { + name: "scores" + description: <<END +A 1-D float tensor of shape `[num_boxes]` representing a single +score corresponding to each box (each row of boxes). +END + } + in_arg { + name: "max_output_size" + description: <<END +A scalar integer tensor representing the maximum number of +boxes to be selected by non max suppression. +END + } + in_arg { + name: "iou_threshold" + description: <<END +A 0-D float tensor representing the threshold for deciding whether +boxes overlap too much with respect to IOU. +END + } + in_arg { + name: "score_threshold" + description: <<END +A 0-D float tensor representing the threshold for deciding when to remove +boxes based on score. +END + } + attr { + name: "pad_to_max_output_size" + description: <<END +If true, the output `selected_indices` is padded to be of length +`max_output_size`. Defaults to false. +END + } + out_arg { + name: "selected_indices" + description: <<END +A 1-D integer tensor of shape `[M]` representing the selected +indices from the boxes tensor, where `M <= max_output_size`. +END + } + out_arg { + name: "valid_outputs" + description: <<END +A 0-D integer tensor representing the number of valid elements in +`selected_indices`, with the valid elements appearing first. +END + } + summary: "Greedily selects a subset of bounding boxes in descending order of score," + description: <<END +pruning away boxes that have high intersection-over-union (IOU) overlap +with previously selected boxes. Bounding boxes with score less than +`score_threshold` are removed. Bounding boxes are supplied as +[y1, x1, y2, x2], where (y1, x1) and (y2, x2) are the coordinates of any +diagonal pair of box corners and the coordinates can be provided as normalized +(i.e., lying in the interval [0, 1]) or absolute. Note that this algorithm +is agnostic to where the origin is in the coordinate system and more +generally is invariant to orthogonal transformations and translations +of the coordinate system; thus translating or reflections of the coordinate +system result in the same boxes being selected by the algorithm. +The output of this operation is a set of integers indexing into the input +collection of bounding boxes representing the selected boxes. The bounding +box coordinates corresponding to the selected indices can then be obtained +using the `tf.gather operation`. For example: + selected_indices = tf.image.non_max_suppression_v2( + boxes, scores, max_output_size, iou_threshold, score_threshold) + selected_boxes = tf.gather(boxes, selected_indices) +END +} diff --git a/tensorflow/core/api_def/python_api/api_def_NonMaxSuppressionV4.pbtxt b/tensorflow/core/api_def/python_api/api_def_NonMaxSuppressionV4.pbtxt new file mode 100644 index 0000000000..be6caacd00 --- /dev/null +++ b/tensorflow/core/api_def/python_api/api_def_NonMaxSuppressionV4.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "NonMaxSuppressionV4" + visibility: HIDDEN +} diff --git a/tensorflow/core/kernels/non_max_suppression_op.cc b/tensorflow/core/kernels/non_max_suppression_op.cc index f59843a07a..c7d0d4de0d 100644 --- a/tensorflow/core/kernels/non_max_suppression_op.cc +++ b/tensorflow/core/kernels/non_max_suppression_op.cc @@ -121,10 +121,11 @@ static inline std::function<bool(int, int)> CreateOverlapsSuppressCheckFn( std::placeholders::_1, std::placeholders::_2, threshold); } -void DoNonMaxSuppressionOp(OpKernelContext* context, const Tensor& scores, - int num_boxes, const Tensor& max_output_size, - const float score_threshold, - std::function<bool(int, int)> suppress_check_fn) { +void DoNonMaxSuppressionOp( + OpKernelContext* context, const Tensor& scores, int num_boxes, + const Tensor& max_output_size, const float score_threshold, + const std::function<bool(int, int)>& suppress_check_fn, + bool pad_to_max_output_size = false, int* ptr_num_valid_outputs = nullptr) { const int output_size = std::min(max_output_size.scalar<int>()(), num_boxes); std::vector<float> scores_data(num_boxes); @@ -172,6 +173,15 @@ void DoNonMaxSuppressionOp(OpKernelContext* context, const Tensor& scores, } } + int num_valid_outputs = selected.size(); + if (pad_to_max_output_size) { + selected.resize(output_size, 0); + selected_scores.resize(output_size, 0); + } + if (ptr_num_valid_outputs) { + *ptr_num_valid_outputs = num_valid_outputs; + } + // Allocate output tensors Tensor* output_indices = nullptr; TensorShape output_shape({static_cast<int>(selected.size())}); @@ -262,54 +272,106 @@ class NonMaxSuppressionV2Op : public OpKernel { } }; -template <typename Device> -class NonMaxSuppressionV3Op : public OpKernel { +class NonMaxSuppressionV3V4Base : public OpKernel { public: - explicit NonMaxSuppressionV3Op(OpKernelConstruction* context) + explicit NonMaxSuppressionV3V4Base(OpKernelConstruction* context) : OpKernel(context) {} void Compute(OpKernelContext* context) override { // boxes: [num_boxes, 4] - const Tensor& boxes = context->input(0); + boxes_ = context->input(0); // scores: [num_boxes] - const Tensor& scores = context->input(1); + scores_ = context->input(1); // max_output_size: scalar - const Tensor& max_output_size = context->input(2); + max_output_size_ = context->input(2); OP_REQUIRES( - context, TensorShapeUtils::IsScalar(max_output_size.shape()), + context, TensorShapeUtils::IsScalar(max_output_size_.shape()), errors::InvalidArgument("max_output_size must be 0-D, got shape ", - max_output_size.shape().DebugString())); + max_output_size_.shape().DebugString())); // iou_threshold: scalar const Tensor& iou_threshold = context->input(3); OP_REQUIRES(context, TensorShapeUtils::IsScalar(iou_threshold.shape()), errors::InvalidArgument("iou_threshold must be 0-D, got shape ", iou_threshold.shape().DebugString())); - const float iou_threshold_val = iou_threshold.scalar<float>()(); - + iou_threshold_val_ = iou_threshold.scalar<float>()(); + OP_REQUIRES(context, iou_threshold_val_ >= 0 && iou_threshold_val_ <= 1, + errors::InvalidArgument("iou_threshold must be in [0, 1]")); // score_threshold: scalar const Tensor& score_threshold = context->input(4); OP_REQUIRES( context, TensorShapeUtils::IsScalar(score_threshold.shape()), errors::InvalidArgument("score_threshold must be 0-D, got shape ", score_threshold.shape().DebugString())); - const float score_threshold_val = score_threshold.scalar<float>()(); + score_threshold_val_ = score_threshold.scalar<float>()(); - OP_REQUIRES(context, iou_threshold_val >= 0 && iou_threshold_val <= 1, - errors::InvalidArgument("iou_threshold must be in [0, 1]")); - int num_boxes = 0; - ParseAndCheckBoxSizes(context, boxes, &num_boxes); - CheckScoreSizes(context, num_boxes, scores); + num_boxes_ = 0; + ParseAndCheckBoxSizes(context, boxes_, &num_boxes_); + CheckScoreSizes(context, num_boxes_, scores_); if (!context->status().ok()) { return; } - auto suppress_check_fn = CreateIOUSuppressCheckFn(boxes, iou_threshold_val); - DoNonMaxSuppressionOp(context, scores, num_boxes, max_output_size, - score_threshold_val, suppress_check_fn); + DoComputeAndPostProcess(context); + } + + protected: + virtual void DoComputeAndPostProcess(OpKernelContext* context) = 0; + + Tensor boxes_; + Tensor scores_; + Tensor max_output_size_; + int num_boxes_; + float iou_threshold_val_; + float score_threshold_val_; +}; + +template <typename Device> +class NonMaxSuppressionV3Op : public NonMaxSuppressionV3V4Base { + public: + explicit NonMaxSuppressionV3Op(OpKernelConstruction* context) + : NonMaxSuppressionV3V4Base(context) {} + + protected: + void DoComputeAndPostProcess(OpKernelContext* context) override { + auto suppress_check_fn = + CreateIOUSuppressCheckFn(boxes_, iou_threshold_val_); + + DoNonMaxSuppressionOp(context, scores_, num_boxes_, max_output_size_, + score_threshold_val_, suppress_check_fn); } }; template <typename Device> +class NonMaxSuppressionV4Op : public NonMaxSuppressionV3V4Base { + public: + explicit NonMaxSuppressionV4Op(OpKernelConstruction* context) + : NonMaxSuppressionV3V4Base(context) { + OP_REQUIRES_OK(context, context->GetAttr("pad_to_max_output_size", + &pad_to_max_output_size_)); + } + + protected: + void DoComputeAndPostProcess(OpKernelContext* context) override { + auto suppress_check_fn = + CreateIOUSuppressCheckFn(boxes_, iou_threshold_val_); + int num_valid_outputs; + + DoNonMaxSuppressionOp(context, scores_, num_boxes_, max_output_size_, + score_threshold_val_, suppress_check_fn, + pad_to_max_output_size_, &num_valid_outputs); + + // Allocate scalar output tensor for number of indices computed. + Tensor* num_outputs_t = nullptr; + OP_REQUIRES_OK(context, context->allocate_output( + 1, tensorflow::TensorShape{}, &num_outputs_t)); + num_outputs_t->scalar<int32>().setConstant(num_valid_outputs); + } + + private: + bool pad_to_max_output_size_; +}; + +template <typename Device> class NonMaxSuppressionWithOverlapsOp : public OpKernel { public: explicit NonMaxSuppressionWithOverlapsOp(OpKernelConstruction* context) @@ -365,6 +427,9 @@ REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV2").Device(DEVICE_CPU), REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV3").Device(DEVICE_CPU), NonMaxSuppressionV3Op<CPUDevice>); +REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV4").Device(DEVICE_CPU), + NonMaxSuppressionV4Op<CPUDevice>); + REGISTER_KERNEL_BUILDER( Name("NonMaxSuppressionWithOverlaps").Device(DEVICE_CPU), NonMaxSuppressionWithOverlapsOp<CPUDevice>); diff --git a/tensorflow/core/kernels/non_max_suppression_op_test.cc b/tensorflow/core/kernels/non_max_suppression_op_test.cc index 055161a35f..c321849f40 100644 --- a/tensorflow/core/kernels/non_max_suppression_op_test.cc +++ b/tensorflow/core/kernels/non_max_suppression_op_test.cc @@ -570,6 +570,61 @@ TEST_F(NonMaxSuppressionV3OpTest, TestEmptyInput) { } // +// NonMaxSuppressionV4Op Tests +// + +class NonMaxSuppressionV4OpTest : public OpsTestBase { + protected: + void MakeOp() { + TF_EXPECT_OK(NodeDefBuilder("non_max_suppression_op", "NonMaxSuppressionV4") + .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(DT_INT32)) + .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(DT_FLOAT)) + .Attr("pad_to_max_output_size", true) + .Finalize(node_def())); + TF_EXPECT_OK(InitOp()); + } +}; + +TEST_F(NonMaxSuppressionV4OpTest, TestSelectFromThreeClustersPadFive) { + MakeOp(); + AddInputFromArray<float>( + TensorShape({6, 4}), + {0, 0, 1, 1, 0, 0.1f, 1, 1.1f, 0, -0.1f, 1, 0.9f, + 0, 10, 1, 11, 0, 10.1f, 1, 11.1f, 0, 100, 1, 101}); + AddInputFromArray<float>(TensorShape({6}), {.9f, .75f, .6f, .95f, .5f, .3f}); + AddInputFromArray<int>(TensorShape({}), {5}); + AddInputFromArray<float>(TensorShape({}), {.5f}); + AddInputFromArray<float>(TensorShape({}), {0.0f}); + TF_ASSERT_OK(RunOpKernel()); + + const auto expected_indices = test::AsTensor<int>({3, 0, 5, 0, 0}); + test::ExpectTensorEqual<int>(expected_indices, *GetOutput(0)); + Tensor expected_num_valid = test::AsScalar<int>(3); + test::ExpectTensorEqual<int>(expected_num_valid, *GetOutput(1)); +} + +TEST_F(NonMaxSuppressionV4OpTest, TestSelectFromThreeClustersPadFiveScoreThr) { + MakeOp(); + AddInputFromArray<float>( + TensorShape({6, 4}), + {0, 0, 1, 1, 0, 0.1f, 1, 1.1f, 0, -0.1f, 1, 0.9f, + 0, 10, 1, 11, 0, 10.1f, 1, 11.1f, 0, 100, 1, 101}); + AddInputFromArray<float>(TensorShape({6}), {.9f, .75f, .6f, .95f, .5f, .3f}); + AddInputFromArray<int>(TensorShape({}), {6}); + AddInputFromArray<float>(TensorShape({}), {.5f}); + AddInputFromArray<float>(TensorShape({}), {0.4f}); + TF_ASSERT_OK(RunOpKernel()); + + const auto expected_indices = test::AsTensor<int>({3, 0, 0, 0, 0, 0}); + test::ExpectTensorEqual<int>(expected_indices, *GetOutput(0)); + Tensor expected_num_valid = test::AsScalar<int>(2); + test::ExpectTensorEqual<int>(expected_num_valid, *GetOutput(1)); +} + +// // NonMaxSuppressionWithOverlapsOp Tests // diff --git a/tensorflow/core/ops/image_ops.cc b/tensorflow/core/ops/image_ops.cc index 50ced1ff73..31267f72b8 100644 --- a/tensorflow/core/ops/image_ops.cc +++ b/tensorflow/core/ops/image_ops.cc @@ -442,8 +442,9 @@ REGISTER_OP("DrawBoundingBoxes") if (c->ValueKnown(c->Dim(images, 3))) { int64 depth = c->Value(c->Dim(images, 3)); if (!(depth == 1 || depth == 3 || depth == 4)) { - return errors::InvalidArgument("Channel depth should be either 1 (GRY), " - "3 (RGB), or 4 (RGBA)"); + return errors::InvalidArgument( + "Channel depth should be either 1 (GRY), " + "3 (RGB), or 4 (RGBA)"); } } @@ -709,6 +710,40 @@ REGISTER_OP("NonMaxSuppressionV3") return Status::OK(); }); +REGISTER_OP("NonMaxSuppressionV4") + .Input("boxes: float") + .Input("scores: float") + .Input("max_output_size: int32") + .Input("iou_threshold: float") + .Input("score_threshold: float") + .Output("selected_indices: int32") + .Output("valid_outputs: int32") + .Attr("pad_to_max_output_size: bool = false") + .SetShapeFn([](InferenceContext* c) { + // Get inputs and validate ranks. + ShapeHandle boxes; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &boxes)); + ShapeHandle scores; + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &scores)); + ShapeHandle max_output_size; + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &max_output_size)); + ShapeHandle iou_threshold; + TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &iou_threshold)); + ShapeHandle score_threshold; + TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &score_threshold)); + // The boxes is a 2-D float Tensor of shape [num_boxes, 4]. + DimensionHandle unused; + // The boxes[0] and scores[0] are both num_boxes. + TF_RETURN_IF_ERROR( + c->Merge(c->Dim(boxes, 0), c->Dim(scores, 0), &unused)); + // The boxes[1] is 4. + TF_RETURN_IF_ERROR(c->WithValue(c->Dim(boxes, 1), 4, &unused)); + + c->set_output(0, c->Vector(c->UnknownDim())); + c->set_output(1, c->MakeShape({})); + return Status::OK(); + }); + REGISTER_OP("NonMaxSuppressionWithOverlaps") .Input("overlaps: float") .Input("scores: float") diff --git a/tensorflow/python/ops/image_ops_impl.py b/tensorflow/python/ops/image_ops_impl.py index 9440bab9ee..855a4d0c33 100644 --- a/tensorflow/python/ops/image_ops_impl.py +++ b/tensorflow/python/ops/image_ops_impl.py @@ -20,6 +20,7 @@ from __future__ import print_function import numpy as np +from tensorflow.python.compat import compat from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -2110,6 +2111,64 @@ def non_max_suppression(boxes, iou_threshold, score_threshold) +@tf_export('image.non_max_suppression_padded') +def non_max_suppression_padded(boxes, + scores, + max_output_size, + iou_threshold=0.5, + score_threshold=float('-inf'), + pad_to_max_output_size=False, + name=None): + """Greedily selects a subset of bounding boxes in descending order of score. + + Performs algorithmically equivalent operation to tf.image.non_max_suppression, + with the addition of an optional parameter which zero-pads the output to + be of size `max_output_size`. + The output of this operation is a tuple containing the set of integers + indexing into the input collection of bounding boxes representing the selected + boxes and the number of valid indices in the index set. The bounding box + coordinates corresponding to the selected indices can then be obtained using + the `tf.slice` and `tf.gather` operations. For example: + selected_indices_padded, num_valid = tf.image.non_max_suppression_padded( + boxes, scores, max_output_size, iou_threshold, + score_threshold, pad_to_max_output_size=True) + selected_indices = tf.slice( + selected_indices_padded, tf.constant([0]), num_valid) + selected_boxes = tf.gather(boxes, selected_indices) + + Args: + boxes: A 2-D float `Tensor` of shape `[num_boxes, 4]`. + scores: A 1-D float `Tensor` of shape `[num_boxes]` representing a single + score corresponding to each box (each row of boxes). + max_output_size: A scalar integer `Tensor` representing the maximum number + of boxes to be selected by non max suppression. + iou_threshold: A float representing the threshold for deciding whether boxes + overlap too much with respect to IOU. + score_threshold: A float representing the threshold for deciding when to + remove boxes based on score. + pad_to_max_output_size: bool. If True, size of `selected_indices` output + is padded to `max_output_size`. + name: A name for the operation (optional). + + Returns: + selected_indices: A 1-D integer `Tensor` of shape `[M]` representing the + selected indices from the boxes tensor, where `M <= max_output_size`. + valid_outputs: A scalar integer `Tensor` denoting how many elements in + `selected_indices` are valid. Valid elements occur first, then padding. + """ + with ops.name_scope(name, 'non_max_suppression_padded'): + iou_threshold = ops.convert_to_tensor(iou_threshold, name='iou_threshold') + score_threshold = ops.convert_to_tensor( + score_threshold, name='score_threshold') + if compat.forward_compatible(2018, 8, 7) or pad_to_max_output_size: + return gen_image_ops.non_max_suppression_v4( + boxes, scores, max_output_size, iou_threshold, score_threshold, + pad_to_max_output_size) + else: + return gen_image_ops.non_max_suppression_v3( + boxes, scores, max_output_size, iou_threshold, score_threshold) + + @tf_export('image.non_max_suppression_overlaps') def non_max_suppression_with_overlaps(overlaps, scores, diff --git a/tensorflow/tools/api/golden/tensorflow.image.pbtxt b/tensorflow/tools/api/golden/tensorflow.image.pbtxt index 6ec3aba775..5c46dc5ee7 100644 --- a/tensorflow/tools/api/golden/tensorflow.image.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.image.pbtxt @@ -125,6 +125,10 @@ tf_module { argspec: "args=[\'overlaps\', \'scores\', \'max_output_size\', \'overlap_threshold\', \'score_threshold\', \'name\'], varargs=None, keywords=None, defaults=[\'0.5\', \'-inf\', \'None\'], " } member_method { + name: "non_max_suppression_padded" + argspec: "args=[\'boxes\', \'scores\', \'max_output_size\', \'iou_threshold\', \'score_threshold\', \'pad_to_max_output_size\', \'name\'], varargs=None, keywords=None, defaults=[\'0.5\', \'-inf\', \'False\', \'None\'], " + } + member_method { name: "pad_to_bounding_box" argspec: "args=[\'image\', \'offset_height\', \'offset_width\', \'target_height\', \'target_width\'], varargs=None, keywords=None, defaults=None" } |