aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Tayo Oguntebi <tayo@google.com>2018-07-30 17:13:09 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-30 17:16:55 -0700
commit8566d9e6fa7dbe3660339befe8b0a3344d24ef2b (patch)
tree54143734f9b19a9ead1a23784f75dcd2009d1452
parent8fee2e4b7c915d952332dc8cc9be7cfefea35162 (diff)
Adds a NonMaxSuppressionV4 op, with a corresponding TF2XLA implementation.
PiperOrigin-RevId: 206673787
-rw-r--r--tensorflow/compiler/tests/image_ops_test.py136
-rw-r--r--tensorflow/compiler/tf2xla/kernels/BUILD1
-rw-r--r--tensorflow/compiler/tf2xla/kernels/image_ops.cc150
-rw-r--r--tensorflow/compiler/tf2xla/kernels/topk_op.cc27
-rw-r--r--tensorflow/compiler/xla/client/lib/BUILD31
-rw-r--r--tensorflow/compiler/xla/client/lib/sorting.cc46
-rw-r--r--tensorflow/compiler/xla/client/lib/sorting.h31
-rw-r--r--tensorflow/compiler/xla/client/lib/sorting_test.cc60
-rw-r--r--tensorflow/core/api_def/base_api/api_def_NonMaxSuppressionV4.pbtxt78
-rw-r--r--tensorflow/core/api_def/python_api/api_def_NonMaxSuppressionV4.pbtxt4
-rw-r--r--tensorflow/core/kernels/non_max_suppression_op.cc111
-rw-r--r--tensorflow/core/kernels/non_max_suppression_op_test.cc55
-rw-r--r--tensorflow/core/ops/image_ops.cc39
-rw-r--r--tensorflow/python/ops/image_ops_impl.py59
-rw-r--r--tensorflow/tools/api/golden/tensorflow.image.pbtxt4
15 files changed, 784 insertions, 48 deletions
diff --git a/tensorflow/compiler/tests/image_ops_test.py b/tensorflow/compiler/tests/image_ops_test.py
index 8b01ef96db..bf986ade06 100644
--- a/tensorflow/compiler/tests/image_ops_test.py
+++ b/tensorflow/compiler/tests/image_ops_test.py
@@ -26,6 +26,7 @@ import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.compiler.tests import xla_test
+from tensorflow.python.compat import compat
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
@@ -579,5 +580,140 @@ class ResizeBilinearTest(xla_test.XLATestCase):
large_tolerance=True)
+class NonMaxSuppressionTest(xla_test.XLATestCase):
+
+ def testNMS128From1024(self):
+ # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU.
+ if self.device in ["XLA_CPU", "XLA_GPU"]:
+ return
+
+ with compat.forward_compatibility_horizon(2018, 8, 8):
+ num_boxes = 1024
+ boxes_np = np.random.normal(50, 10, (num_boxes, 4)).astype("f4")
+ scores_np = np.random.normal(0.5, 0.1, (num_boxes,)).astype("f4")
+
+ max_output_size = 128
+ iou_threshold_np = np.array(0.5, dtype=np.float32)
+ score_threshold_np = np.array(0.0, dtype=np.float32)
+
+ with self.test_session() as sess:
+ boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape)
+ scores = array_ops.placeholder(scores_np.dtype, shape=scores_np.shape)
+ iou_threshold = array_ops.placeholder(iou_threshold_np.dtype,
+ iou_threshold_np.shape)
+ score_threshold = array_ops.placeholder(score_threshold_np.dtype,
+ score_threshold_np.shape)
+ with self.test_scope():
+ selected_indices = image_ops.non_max_suppression_padded(
+ boxes=boxes,
+ scores=scores,
+ max_output_size=max_output_size,
+ iou_threshold=iou_threshold,
+ score_threshold=score_threshold,
+ pad_to_max_output_size=True)
+ inputs_feed = {
+ boxes: boxes_np,
+ scores: scores_np,
+ score_threshold: score_threshold_np,
+ iou_threshold: iou_threshold_np
+ }
+ (indices_tf, _) = sess.run(selected_indices, feed_dict=inputs_feed)
+
+ self.assertEqual(indices_tf.size, max_output_size)
+
+ def testNMS3From6Boxes(self):
+ # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU.
+ if self.device in ["XLA_CPU", "XLA_GPU"]:
+ return
+
+ with compat.forward_compatibility_horizon(2018, 8, 8):
+ # Three boxes are selected based on IOU.
+ boxes_data = [[0, 0, 1, 1], [0, 0.1, 1, 1.1], [0, -0.1, 1, 0.9],
+ [0, 10, 1, 11], [0, 10.1, 1, 11.1], [0, 100, 1, 101]]
+ boxes_np = np.array(boxes_data, dtype=np.float32)
+
+ scores_data = [0.9, 0.75, 0.6, 0.95, 0.5, 0.3]
+ scores_np = np.array(scores_data, dtype=np.float32)
+
+ max_output_size = 3
+ iou_threshold_np = np.array(0.5, dtype=np.float32)
+ score_threshold_np = np.array(0.0, dtype=np.float32)
+
+ with self.test_session() as sess:
+ boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape)
+ scores = array_ops.placeholder(scores_np.dtype, shape=scores_np.shape)
+ iou_threshold = array_ops.placeholder(iou_threshold_np.dtype,
+ iou_threshold_np.shape)
+ score_threshold = array_ops.placeholder(score_threshold_np.dtype,
+ score_threshold_np.shape)
+ with self.test_scope():
+ selected_indices = image_ops.non_max_suppression_padded(
+ boxes=boxes,
+ scores=scores,
+ max_output_size=max_output_size,
+ iou_threshold=iou_threshold,
+ score_threshold=score_threshold,
+ pad_to_max_output_size=True)
+ inputs_feed = {
+ boxes: boxes_np,
+ scores: scores_np,
+ score_threshold: score_threshold_np,
+ iou_threshold: iou_threshold_np
+ }
+ (indices_tf, num_valid) = sess.run(
+ selected_indices, feed_dict=inputs_feed)
+
+ self.assertEqual(indices_tf.size, max_output_size)
+ self.assertEqual(num_valid, 3)
+ self.assertAllClose(indices_tf[:num_valid], [3, 0, 5])
+
+ def testNMS3Then2WithScoreThresh(self):
+ # Three boxes are selected based on IOU.
+ # One is filtered out by score threshold.
+
+ # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU.
+ if self.device in ["XLA_CPU", "XLA_GPU"]:
+ return
+
+ with compat.forward_compatibility_horizon(2018, 8, 8):
+ boxes_data = [[0, 0, 1, 1], [0, 0.1, 1, 1.1], [0, -0.1, 1, 0.9],
+ [0, 10, 1, 11], [0, 10.1, 1, 11.1], [0, 100, 1, 101]]
+ boxes_np = np.array(boxes_data, dtype=np.float32)
+
+ scores_data = [0.9, 0.75, 0.6, 0.95, 0.5, 0.3]
+ scores_np = np.array(scores_data, dtype=np.float32)
+ max_output_size = 3
+ iou_threshold_np = np.array(0.5, dtype=np.float32)
+ score_threshold_np = np.array(0.4, dtype=np.float32)
+
+ with self.test_session() as sess:
+ boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape)
+ scores = array_ops.placeholder(scores_np.dtype, shape=scores_np.shape)
+ iou_threshold = array_ops.placeholder(iou_threshold_np.dtype,
+ iou_threshold_np.shape)
+ score_threshold = array_ops.placeholder(score_threshold_np.dtype,
+ score_threshold_np.shape)
+ with self.test_scope():
+ selected_indices = image_ops.non_max_suppression_padded(
+ boxes=boxes,
+ scores=scores,
+ max_output_size=max_output_size,
+ iou_threshold=iou_threshold,
+ score_threshold=score_threshold,
+ pad_to_max_output_size=True)
+ inputs_feed = {
+ boxes: boxes_np,
+ scores: scores_np,
+ iou_threshold: iou_threshold_np,
+ score_threshold: score_threshold_np
+ }
+ (indices_tf, num_valid) = sess.run(
+ selected_indices, feed_dict=inputs_feed)
+
+ self.assertEqual(indices_tf.size, max_output_size)
+ self.assertEqual(num_valid, 2)
+ self.assertAllClose(indices_tf[:num_valid], [3, 0])
+
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD
index f96483d23d..0609e22338 100644
--- a/tensorflow/compiler/tf2xla/kernels/BUILD
+++ b/tensorflow/compiler/tf2xla/kernels/BUILD
@@ -130,6 +130,7 @@ tf_kernel_library(
"//tensorflow/compiler/xla/client/lib:math",
"//tensorflow/compiler/xla/client/lib:numeric",
"//tensorflow/compiler/xla/client/lib:prng",
+ "//tensorflow/compiler/xla/client/lib:sorting",
"//tensorflow/core:framework",
"//tensorflow/core:image_ops_op_lib",
"//tensorflow/core:lib",
diff --git a/tensorflow/compiler/tf2xla/kernels/image_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_ops.cc
index 6e061ba278..33a73fe5fd 100644
--- a/tensorflow/compiler/tf2xla/kernels/image_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/image_ops.cc
@@ -17,7 +17,12 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
+#include "tensorflow/compiler/xla/client/lib/constants.h"
+#include "tensorflow/compiler/xla/client/lib/sorting.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/core/framework/tensor_shape.h"
namespace tensorflow {
namespace {
@@ -311,5 +316,150 @@ class AdjustHueOp : public XlaOpKernel {
};
REGISTER_XLA_OP(Name("AdjustHue"), AdjustHueOp);
+class NonMaxSuppressionOp : public XlaOpKernel {
+ public:
+ explicit NonMaxSuppressionOp(OpKernelConstruction* context)
+ : XlaOpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("pad_to_max_output_size",
+ &pad_to_max_output_size_));
+ }
+
+ void Compile(XlaOpKernelContext* context) override {
+ // TODO(b/111646731): Improve scalability of this op, using blocking.
+ int num_boxes_dim = 0;
+ int coords_dim = 1;
+ const TensorShape& boxes_shape = context->InputShape("boxes");
+ OP_REQUIRES(context, TensorShapeUtils::IsMatrix(boxes_shape),
+ errors::InvalidArgument("boxes must be 2-D, currently: ",
+ boxes_shape.DebugString()));
+ const int64 num_boxes = boxes_shape.dim_size(num_boxes_dim);
+ OP_REQUIRES(context, boxes_shape.dim_size(coords_dim) == 4,
+ errors::InvalidArgument("boxes must have 4 columns",
+ boxes_shape.DebugString()));
+ const TensorShape& scores_shape = context->InputShape("scores");
+ OP_REQUIRES(context, TensorShapeUtils::IsVector(scores_shape),
+ errors::InvalidArgument("scores must be 1-D, currently: ",
+ scores_shape.DebugString()));
+ OP_REQUIRES(
+ context, scores_shape.dim_size(0) == num_boxes,
+ errors::InvalidArgument("scores size must equal number of boxes",
+ scores_shape.DebugString()));
+ OP_REQUIRES(context, pad_to_max_output_size_,
+ errors::InvalidArgument(
+ "XLA compilation requires pad_to_max_output_size == True"));
+
+ xla::XlaOp boxes = context->Input("boxes");
+ xla::XlaOp scores = context->Input("scores");
+ int64 output_size;
+ OP_REQUIRES_OK(context, context->ConstantInputAsIntScalar(2, &output_size));
+ OP_REQUIRES(
+ context, output_size >= 0,
+ errors::InvalidArgument("Need output_size >= 0, got ", output_size));
+ xla::XlaOp score_thresh = context->Input("score_threshold");
+ xla::XlaOp iou_thresh = context->Input("iou_threshold");
+
+ xla::XlaBuilder* const builder = context->builder();
+
+ // Choose a more convenient layout.
+ xla::XlaOp boxes_t = xla::Transpose(boxes, {1, 0});
+ coords_dim = 0;
+ num_boxes_dim = 1;
+
+ // Shapes are henceforth [1, num_boxes].
+ xla::XlaOp coord_y0 = xla::SliceInDim(boxes_t,
+ /*start_index=*/0,
+ /*limit_index=*/1,
+ /*stride=*/1,
+ /*dimno=*/coords_dim);
+ xla::XlaOp coord_x0 = xla::SliceInDim(boxes_t,
+ /*start_index=*/1,
+ /*limit_index=*/2,
+ /*stride=*/1,
+ /*dimno=*/coords_dim);
+ xla::XlaOp coord_y1 = xla::SliceInDim(boxes_t,
+ /*start_index=*/2,
+ /*limit_index=*/3,
+ /*stride=*/1,
+ /*dimno=*/coords_dim);
+ xla::XlaOp coord_x1 = xla::SliceInDim(boxes_t,
+ /*start_index=*/3,
+ /*limit_index=*/4,
+ /*stride=*/1,
+ /*dimno=*/coords_dim);
+ xla::XlaOp y1 =
+ xla::Select(xla::Le(coord_y0, coord_y1), coord_y0, coord_y1);
+ xla::XlaOp y2 =
+ xla::Select(xla::Le(coord_y0, coord_y1), coord_y1, coord_y0);
+ xla::XlaOp x1 =
+ xla::Select(xla::Le(coord_x0, coord_x1), coord_x0, coord_x1);
+ xla::XlaOp x2 =
+ xla::Select(xla::Le(coord_x0, coord_x1), coord_x1, coord_x0);
+ xla::XlaOp area = (y2 - y1) * (x2 - x1);
+
+ // Transpose the 1xN tensors, instead of the NxN tensors.
+ xla::XlaOp y1_t = xla::Transpose(y1, {1, 0});
+ xla::XlaOp y2_t = xla::Transpose(y2, {1, 0});
+ xla::XlaOp x1_t = xla::Transpose(x1, {1, 0});
+ xla::XlaOp x2_t = xla::Transpose(x2, {1, 0});
+ xla::XlaOp area_t = xla::Transpose(area, {1, 0});
+
+ // Shapes are henceforth [num_boxes, num_boxes].
+ xla::XlaOp i_xmin = xla::Max(x1, x1_t);
+ xla::XlaOp i_ymin = xla::Max(y1, y1_t);
+ xla::XlaOp i_xmax = xla::Min(x2, x2_t);
+ xla::XlaOp i_ymax = xla::Min(y2, y2_t);
+ auto square_zero = xla::ZerosLike(i_xmin);
+
+ xla::XlaOp i_area = xla::Max(i_xmax - i_xmin, square_zero) *
+ xla::Max(i_ymax - i_ymin, square_zero);
+ xla::XlaOp u_area = area + area_t - i_area;
+ xla::XlaOp iou = i_area / u_area;
+
+ xla::XlaOp iou_thresh_mask = xla::Gt(iou, iou_thresh + square_zero);
+ xla::XlaOp scores_2d = xla::Reshape(scores, {num_boxes, 1});
+ xla::XlaOp score_cmp_mask =
+ xla::Gt(scores_2d, xla::Transpose(scores_2d, {1, 0}));
+ xla::XlaOp suppress = xla::And(iou_thresh_mask, score_cmp_mask);
+
+ // Shapes are [num_boxes] after the reduce.
+ xla::XlaOp included_iou = xla::Not(xla::Reduce(
+ suppress,
+ /*init_value=*/xla::ConstantR0<bool>(builder, false),
+ /*computation=*/CreateScalarOrComputation(xla::PRED, builder),
+ /*dimensions_to_reduce=*/{0}));
+ xla::XlaOp included_score =
+ xla::Gt(scores, xla::Broadcast(score_thresh, {num_boxes}));
+ xla::XlaOp included = xla::And(included_iou, included_score);
+ xla::XlaOp neg_inf =
+ xla::Broadcast(xla::MinValue(builder, xla::F32), {num_boxes});
+ xla::XlaOp scores_included = xla::Select(included, scores, neg_inf);
+
+ xla::XlaOp ones_included = xla::Select(
+ included,
+ xla::Broadcast(xla::ConstantR0<int32>(builder, 1), {num_boxes}),
+ xla::Broadcast(xla::ConstantR0<int32>(builder, 0), {num_boxes}));
+
+ // num_valid is scalar.
+ xla::XlaOp num_valid = xla::Reduce(
+ ones_included,
+ /*init_value=*/xla::ConstantR0<int>(builder, 0),
+ /*computation=*/CreateScalarAddComputation(xla::S32, builder),
+ /*dimensions_to_reduce=*/{0});
+
+ xla::XlaOp output_tuple = TopK(scores_included, output_size);
+ xla::XlaOp selected_indices = xla::GetTupleElement(output_tuple, 1);
+
+ context->SetOutput(0, selected_indices);
+ context->SetOutput(1, num_valid);
+ }
+
+ private:
+ bool pad_to_max_output_size_;
+};
+
+REGISTER_XLA_OP(
+ Name("NonMaxSuppressionV4").CompileTimeConstInput("max_output_size"),
+ NonMaxSuppressionOp);
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/topk_op.cc b/tensorflow/compiler/tf2xla/kernels/topk_op.cc
index e73fb283b0..183879c760 100644
--- a/tensorflow/compiler/tf2xla/kernels/topk_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/topk_op.cc
@@ -13,10 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/numeric.h"
+#include "tensorflow/compiler/xla/client/lib/sorting.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
@@ -47,31 +47,12 @@ class TopKOp : public XlaOpKernel {
context, last_dim_size >= k,
errors::InvalidArgument("input must have at least k columns. Had ",
last_dim_size, ", needed ", k));
-
- xla::XlaBuilder* const b = context->builder();
if (last_dim_size < k) {
k = last_dim_size;
}
- const xla::XlaOp input = context->Input(0);
-
- xla::XlaOp iota_s32 = xla::Iota(b, xla::S32, last_dim_size);
- auto input_dims = input_shape.dim_sizes();
- std::vector<int64> broadcast_dims(input_dims.begin(), input_dims.end() - 1);
- xla::XlaOp broadcast_s32 = xla::Broadcast(iota_s32, broadcast_dims);
- xla::XlaOp sort_result = xla::Sort(xla::Neg(input), broadcast_s32);
-
- std::vector<int64> start_indices(input_shape.dims(), 0);
- std::vector<int64> limit_indices(input_dims.begin(), input_dims.end());
- limit_indices[last_dim] = k;
- std::vector<int64> strides(input_shape.dims(), 1);
-
- xla::XlaOp values =
- xla::Neg(xla::Slice(xla::GetTupleElement(sort_result, 0), start_indices,
- limit_indices, strides));
- xla::XlaOp indices = xla::Slice(xla::GetTupleElement(sort_result, 1),
- start_indices, limit_indices, strides);
- context->SetOutput(0, values);
- context->SetOutput(1, indices);
+ xla::XlaOp output_tuple = TopK(context->Input(0), k);
+ context->SetOutput(0, xla::GetTupleElement(output_tuple, 0));
+ context->SetOutput(1, xla::GetTupleElement(output_tuple, 1));
}
private:
diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD
index 789daf4728..39d5582d19 100644
--- a/tensorflow/compiler/xla/client/lib/BUILD
+++ b/tensorflow/compiler/xla/client/lib/BUILD
@@ -137,6 +137,37 @@ cc_library(
)
cc_library(
+ name = "sorting",
+ srcs = ["sorting.cc"],
+ hdrs = ["sorting.h"],
+ deps = [
+ ":numeric",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ ],
+)
+
+xla_test(
+ name = "sorting_test",
+ srcs = ["sorting_test.cc"],
+ blacklisted_backends = [
+ "cpu",
+ "gpu",
+ ],
+ tags = ["enable_for_xla_interpreter"],
+ deps = [
+ ":sorting",
+ "//tensorflow/compiler/xla:test",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ ],
+)
+
+cc_library(
name = "testing",
srcs = ["testing.cc"],
hdrs = ["testing.h"],
diff --git a/tensorflow/compiler/xla/client/lib/sorting.cc b/tensorflow/compiler/xla/client/lib/sorting.cc
new file mode 100644
index 0000000000..a904be259a
--- /dev/null
+++ b/tensorflow/compiler/xla/client/lib/sorting.cc
@@ -0,0 +1,46 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/client/lib/sorting.h"
+#include "tensorflow/compiler/xla/client/lib/numeric.h"
+
+namespace xla {
+
+XlaOp TopK(XlaOp input, int64 k) {
+ XlaBuilder* const builder = input.builder();
+ return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input));
+ int last_dim = input_shape.dimensions_size() - 1;
+ int last_dim_size = input_shape.dimensions(last_dim);
+
+ XlaOp iota_s32 = Iota(builder, S32, last_dim_size);
+ auto input_dims = input_shape.dimensions();
+ std::vector<int64> broadcast_dims(input_dims.begin(), input_dims.end() - 1);
+ XlaOp broadcast_s32 = Broadcast(iota_s32, broadcast_dims);
+ XlaOp sort_result = Sort(Neg(input), broadcast_s32);
+ std::vector<int64> start_indices(input_shape.dimensions_size(), 0);
+ std::vector<int64> limit_indices(input_dims.begin(), input_dims.end());
+ limit_indices[last_dim] = k;
+ std::vector<int64> strides(input_shape.dimensions_size(), 1);
+
+ XlaOp values = Neg(Slice(GetTupleElement(sort_result, 0), start_indices,
+ limit_indices, strides));
+ XlaOp indices = Slice(GetTupleElement(sort_result, 1), start_indices,
+ limit_indices, strides);
+ return Tuple(builder, {values, indices});
+ });
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/client/lib/sorting.h b/tensorflow/compiler/xla/client/lib/sorting.h
new file mode 100644
index 0000000000..404b4783c3
--- /dev/null
+++ b/tensorflow/compiler/xla/client/lib/sorting.h
@@ -0,0 +1,31 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SORTING_H_
+#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SORTING_H_
+
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+
+namespace xla {
+
+// Returns a tuple composed of the top `k` values and corresponding indices in
+// `input`. Output values are in descending order, from largest to smallest.
+XlaOp TopK(XlaOp input, int64 k);
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SORTING_H_
diff --git a/tensorflow/compiler/xla/client/lib/sorting_test.cc b/tensorflow/compiler/xla/client/lib/sorting_test.cc
new file mode 100644
index 0000000000..b6eee762a5
--- /dev/null
+++ b/tensorflow/compiler/xla/client/lib/sorting_test.cc
@@ -0,0 +1,60 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/client/lib/sorting.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/types.h"
+
+namespace xla {
+namespace {
+
+using SortingTest = ClientLibraryTestBase;
+
+XLA_TEST_F(SortingTest, TopK3From8Values) {
+ XlaBuilder builder(TestName());
+ auto x =
+ ConstantR1<float>(&builder, {0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0});
+ xla::GetTupleElement(xla::TopK(x, 3), 0);
+ ComputeAndCompareR1<float>(&builder, {7.0, 6.0, 5.0}, {});
+}
+
+XLA_TEST_F(SortingTest, TopK3From8Indices) {
+ XlaBuilder builder(TestName());
+ auto x_rev =
+ ConstantR1<float>(&builder, {7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0, 0.0});
+ xla::GetTupleElement(xla::TopK(x_rev, 3), 1);
+ ComputeAndCompareR1<int>(&builder, {0, 1, 2}, {});
+}
+
+XLA_TEST_F(SortingTest, TopKFullSort) {
+ XlaBuilder builder(TestName());
+ const int kSize = 16;
+ std::mt19937 eng;
+ std::uniform_real_distribution<float> u_dist(0.0, 100.0);
+ auto gen = std::bind(u_dist, eng);
+ std::vector<float> inputs(kSize);
+ std::generate(inputs.begin(), inputs.end(), gen);
+ auto x = ConstantR1<float>(&builder, inputs);
+ xla::GetTupleElement(xla::TopK(x, kSize), 0);
+
+ std::sort(inputs.begin(), inputs.end(), std::greater<float>());
+ ComputeAndCompareR1<float>(&builder, inputs, {});
+}
+
+} // namespace
+} // namespace xla
diff --git a/tensorflow/core/api_def/base_api/api_def_NonMaxSuppressionV4.pbtxt b/tensorflow/core/api_def/base_api/api_def_NonMaxSuppressionV4.pbtxt
new file mode 100644
index 0000000000..75df90f570
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_NonMaxSuppressionV4.pbtxt
@@ -0,0 +1,78 @@
+op {
+ graph_op_name: "NonMaxSuppressionV4"
+ in_arg {
+ name: "boxes"
+ description: <<END
+A 2-D float tensor of shape `[num_boxes, 4]`.
+END
+ }
+ in_arg {
+ name: "scores"
+ description: <<END
+A 1-D float tensor of shape `[num_boxes]` representing a single
+score corresponding to each box (each row of boxes).
+END
+ }
+ in_arg {
+ name: "max_output_size"
+ description: <<END
+A scalar integer tensor representing the maximum number of
+boxes to be selected by non max suppression.
+END
+ }
+ in_arg {
+ name: "iou_threshold"
+ description: <<END
+A 0-D float tensor representing the threshold for deciding whether
+boxes overlap too much with respect to IOU.
+END
+ }
+ in_arg {
+ name: "score_threshold"
+ description: <<END
+A 0-D float tensor representing the threshold for deciding when to remove
+boxes based on score.
+END
+ }
+ attr {
+ name: "pad_to_max_output_size"
+ description: <<END
+If true, the output `selected_indices` is padded to be of length
+`max_output_size`. Defaults to false.
+END
+ }
+ out_arg {
+ name: "selected_indices"
+ description: <<END
+A 1-D integer tensor of shape `[M]` representing the selected
+indices from the boxes tensor, where `M <= max_output_size`.
+END
+ }
+ out_arg {
+ name: "valid_outputs"
+ description: <<END
+A 0-D integer tensor representing the number of valid elements in
+`selected_indices`, with the valid elements appearing first.
+END
+ }
+ summary: "Greedily selects a subset of bounding boxes in descending order of score,"
+ description: <<END
+pruning away boxes that have high intersection-over-union (IOU) overlap
+with previously selected boxes. Bounding boxes with score less than
+`score_threshold` are removed. Bounding boxes are supplied as
+[y1, x1, y2, x2], where (y1, x1) and (y2, x2) are the coordinates of any
+diagonal pair of box corners and the coordinates can be provided as normalized
+(i.e., lying in the interval [0, 1]) or absolute. Note that this algorithm
+is agnostic to where the origin is in the coordinate system and more
+generally is invariant to orthogonal transformations and translations
+of the coordinate system; thus translating or reflections of the coordinate
+system result in the same boxes being selected by the algorithm.
+The output of this operation is a set of integers indexing into the input
+collection of bounding boxes representing the selected boxes. The bounding
+box coordinates corresponding to the selected indices can then be obtained
+using the `tf.gather operation`. For example:
+ selected_indices = tf.image.non_max_suppression_v2(
+ boxes, scores, max_output_size, iou_threshold, score_threshold)
+ selected_boxes = tf.gather(boxes, selected_indices)
+END
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_NonMaxSuppressionV4.pbtxt b/tensorflow/core/api_def/python_api/api_def_NonMaxSuppressionV4.pbtxt
new file mode 100644
index 0000000000..be6caacd00
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_NonMaxSuppressionV4.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "NonMaxSuppressionV4"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/kernels/non_max_suppression_op.cc b/tensorflow/core/kernels/non_max_suppression_op.cc
index f59843a07a..c7d0d4de0d 100644
--- a/tensorflow/core/kernels/non_max_suppression_op.cc
+++ b/tensorflow/core/kernels/non_max_suppression_op.cc
@@ -121,10 +121,11 @@ static inline std::function<bool(int, int)> CreateOverlapsSuppressCheckFn(
std::placeholders::_1, std::placeholders::_2, threshold);
}
-void DoNonMaxSuppressionOp(OpKernelContext* context, const Tensor& scores,
- int num_boxes, const Tensor& max_output_size,
- const float score_threshold,
- std::function<bool(int, int)> suppress_check_fn) {
+void DoNonMaxSuppressionOp(
+ OpKernelContext* context, const Tensor& scores, int num_boxes,
+ const Tensor& max_output_size, const float score_threshold,
+ const std::function<bool(int, int)>& suppress_check_fn,
+ bool pad_to_max_output_size = false, int* ptr_num_valid_outputs = nullptr) {
const int output_size = std::min(max_output_size.scalar<int>()(), num_boxes);
std::vector<float> scores_data(num_boxes);
@@ -172,6 +173,15 @@ void DoNonMaxSuppressionOp(OpKernelContext* context, const Tensor& scores,
}
}
+ int num_valid_outputs = selected.size();
+ if (pad_to_max_output_size) {
+ selected.resize(output_size, 0);
+ selected_scores.resize(output_size, 0);
+ }
+ if (ptr_num_valid_outputs) {
+ *ptr_num_valid_outputs = num_valid_outputs;
+ }
+
// Allocate output tensors
Tensor* output_indices = nullptr;
TensorShape output_shape({static_cast<int>(selected.size())});
@@ -262,54 +272,106 @@ class NonMaxSuppressionV2Op : public OpKernel {
}
};
-template <typename Device>
-class NonMaxSuppressionV3Op : public OpKernel {
+class NonMaxSuppressionV3V4Base : public OpKernel {
public:
- explicit NonMaxSuppressionV3Op(OpKernelConstruction* context)
+ explicit NonMaxSuppressionV3V4Base(OpKernelConstruction* context)
: OpKernel(context) {}
void Compute(OpKernelContext* context) override {
// boxes: [num_boxes, 4]
- const Tensor& boxes = context->input(0);
+ boxes_ = context->input(0);
// scores: [num_boxes]
- const Tensor& scores = context->input(1);
+ scores_ = context->input(1);
// max_output_size: scalar
- const Tensor& max_output_size = context->input(2);
+ max_output_size_ = context->input(2);
OP_REQUIRES(
- context, TensorShapeUtils::IsScalar(max_output_size.shape()),
+ context, TensorShapeUtils::IsScalar(max_output_size_.shape()),
errors::InvalidArgument("max_output_size must be 0-D, got shape ",
- max_output_size.shape().DebugString()));
+ max_output_size_.shape().DebugString()));
// iou_threshold: scalar
const Tensor& iou_threshold = context->input(3);
OP_REQUIRES(context, TensorShapeUtils::IsScalar(iou_threshold.shape()),
errors::InvalidArgument("iou_threshold must be 0-D, got shape ",
iou_threshold.shape().DebugString()));
- const float iou_threshold_val = iou_threshold.scalar<float>()();
-
+ iou_threshold_val_ = iou_threshold.scalar<float>()();
+ OP_REQUIRES(context, iou_threshold_val_ >= 0 && iou_threshold_val_ <= 1,
+ errors::InvalidArgument("iou_threshold must be in [0, 1]"));
// score_threshold: scalar
const Tensor& score_threshold = context->input(4);
OP_REQUIRES(
context, TensorShapeUtils::IsScalar(score_threshold.shape()),
errors::InvalidArgument("score_threshold must be 0-D, got shape ",
score_threshold.shape().DebugString()));
- const float score_threshold_val = score_threshold.scalar<float>()();
+ score_threshold_val_ = score_threshold.scalar<float>()();
- OP_REQUIRES(context, iou_threshold_val >= 0 && iou_threshold_val <= 1,
- errors::InvalidArgument("iou_threshold must be in [0, 1]"));
- int num_boxes = 0;
- ParseAndCheckBoxSizes(context, boxes, &num_boxes);
- CheckScoreSizes(context, num_boxes, scores);
+ num_boxes_ = 0;
+ ParseAndCheckBoxSizes(context, boxes_, &num_boxes_);
+ CheckScoreSizes(context, num_boxes_, scores_);
if (!context->status().ok()) {
return;
}
- auto suppress_check_fn = CreateIOUSuppressCheckFn(boxes, iou_threshold_val);
- DoNonMaxSuppressionOp(context, scores, num_boxes, max_output_size,
- score_threshold_val, suppress_check_fn);
+ DoComputeAndPostProcess(context);
+ }
+
+ protected:
+ virtual void DoComputeAndPostProcess(OpKernelContext* context) = 0;
+
+ Tensor boxes_;
+ Tensor scores_;
+ Tensor max_output_size_;
+ int num_boxes_;
+ float iou_threshold_val_;
+ float score_threshold_val_;
+};
+
+template <typename Device>
+class NonMaxSuppressionV3Op : public NonMaxSuppressionV3V4Base {
+ public:
+ explicit NonMaxSuppressionV3Op(OpKernelConstruction* context)
+ : NonMaxSuppressionV3V4Base(context) {}
+
+ protected:
+ void DoComputeAndPostProcess(OpKernelContext* context) override {
+ auto suppress_check_fn =
+ CreateIOUSuppressCheckFn(boxes_, iou_threshold_val_);
+
+ DoNonMaxSuppressionOp(context, scores_, num_boxes_, max_output_size_,
+ score_threshold_val_, suppress_check_fn);
}
};
template <typename Device>
+class NonMaxSuppressionV4Op : public NonMaxSuppressionV3V4Base {
+ public:
+ explicit NonMaxSuppressionV4Op(OpKernelConstruction* context)
+ : NonMaxSuppressionV3V4Base(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("pad_to_max_output_size",
+ &pad_to_max_output_size_));
+ }
+
+ protected:
+ void DoComputeAndPostProcess(OpKernelContext* context) override {
+ auto suppress_check_fn =
+ CreateIOUSuppressCheckFn(boxes_, iou_threshold_val_);
+ int num_valid_outputs;
+
+ DoNonMaxSuppressionOp(context, scores_, num_boxes_, max_output_size_,
+ score_threshold_val_, suppress_check_fn,
+ pad_to_max_output_size_, &num_valid_outputs);
+
+ // Allocate scalar output tensor for number of indices computed.
+ Tensor* num_outputs_t = nullptr;
+ OP_REQUIRES_OK(context, context->allocate_output(
+ 1, tensorflow::TensorShape{}, &num_outputs_t));
+ num_outputs_t->scalar<int32>().setConstant(num_valid_outputs);
+ }
+
+ private:
+ bool pad_to_max_output_size_;
+};
+
+template <typename Device>
class NonMaxSuppressionWithOverlapsOp : public OpKernel {
public:
explicit NonMaxSuppressionWithOverlapsOp(OpKernelConstruction* context)
@@ -365,6 +427,9 @@ REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV2").Device(DEVICE_CPU),
REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV3").Device(DEVICE_CPU),
NonMaxSuppressionV3Op<CPUDevice>);
+REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV4").Device(DEVICE_CPU),
+ NonMaxSuppressionV4Op<CPUDevice>);
+
REGISTER_KERNEL_BUILDER(
Name("NonMaxSuppressionWithOverlaps").Device(DEVICE_CPU),
NonMaxSuppressionWithOverlapsOp<CPUDevice>);
diff --git a/tensorflow/core/kernels/non_max_suppression_op_test.cc b/tensorflow/core/kernels/non_max_suppression_op_test.cc
index 055161a35f..c321849f40 100644
--- a/tensorflow/core/kernels/non_max_suppression_op_test.cc
+++ b/tensorflow/core/kernels/non_max_suppression_op_test.cc
@@ -570,6 +570,61 @@ TEST_F(NonMaxSuppressionV3OpTest, TestEmptyInput) {
}
//
+// NonMaxSuppressionV4Op Tests
+//
+
+class NonMaxSuppressionV4OpTest : public OpsTestBase {
+ protected:
+ void MakeOp() {
+ TF_EXPECT_OK(NodeDefBuilder("non_max_suppression_op", "NonMaxSuppressionV4")
+ .Input(FakeInput(DT_FLOAT))
+ .Input(FakeInput(DT_FLOAT))
+ .Input(FakeInput(DT_INT32))
+ .Input(FakeInput(DT_FLOAT))
+ .Input(FakeInput(DT_FLOAT))
+ .Attr("pad_to_max_output_size", true)
+ .Finalize(node_def()));
+ TF_EXPECT_OK(InitOp());
+ }
+};
+
+TEST_F(NonMaxSuppressionV4OpTest, TestSelectFromThreeClustersPadFive) {
+ MakeOp();
+ AddInputFromArray<float>(
+ TensorShape({6, 4}),
+ {0, 0, 1, 1, 0, 0.1f, 1, 1.1f, 0, -0.1f, 1, 0.9f,
+ 0, 10, 1, 11, 0, 10.1f, 1, 11.1f, 0, 100, 1, 101});
+ AddInputFromArray<float>(TensorShape({6}), {.9f, .75f, .6f, .95f, .5f, .3f});
+ AddInputFromArray<int>(TensorShape({}), {5});
+ AddInputFromArray<float>(TensorShape({}), {.5f});
+ AddInputFromArray<float>(TensorShape({}), {0.0f});
+ TF_ASSERT_OK(RunOpKernel());
+
+ const auto expected_indices = test::AsTensor<int>({3, 0, 5, 0, 0});
+ test::ExpectTensorEqual<int>(expected_indices, *GetOutput(0));
+ Tensor expected_num_valid = test::AsScalar<int>(3);
+ test::ExpectTensorEqual<int>(expected_num_valid, *GetOutput(1));
+}
+
+TEST_F(NonMaxSuppressionV4OpTest, TestSelectFromThreeClustersPadFiveScoreThr) {
+ MakeOp();
+ AddInputFromArray<float>(
+ TensorShape({6, 4}),
+ {0, 0, 1, 1, 0, 0.1f, 1, 1.1f, 0, -0.1f, 1, 0.9f,
+ 0, 10, 1, 11, 0, 10.1f, 1, 11.1f, 0, 100, 1, 101});
+ AddInputFromArray<float>(TensorShape({6}), {.9f, .75f, .6f, .95f, .5f, .3f});
+ AddInputFromArray<int>(TensorShape({}), {6});
+ AddInputFromArray<float>(TensorShape({}), {.5f});
+ AddInputFromArray<float>(TensorShape({}), {0.4f});
+ TF_ASSERT_OK(RunOpKernel());
+
+ const auto expected_indices = test::AsTensor<int>({3, 0, 0, 0, 0, 0});
+ test::ExpectTensorEqual<int>(expected_indices, *GetOutput(0));
+ Tensor expected_num_valid = test::AsScalar<int>(2);
+ test::ExpectTensorEqual<int>(expected_num_valid, *GetOutput(1));
+}
+
+//
// NonMaxSuppressionWithOverlapsOp Tests
//
diff --git a/tensorflow/core/ops/image_ops.cc b/tensorflow/core/ops/image_ops.cc
index 50ced1ff73..31267f72b8 100644
--- a/tensorflow/core/ops/image_ops.cc
+++ b/tensorflow/core/ops/image_ops.cc
@@ -442,8 +442,9 @@ REGISTER_OP("DrawBoundingBoxes")
if (c->ValueKnown(c->Dim(images, 3))) {
int64 depth = c->Value(c->Dim(images, 3));
if (!(depth == 1 || depth == 3 || depth == 4)) {
- return errors::InvalidArgument("Channel depth should be either 1 (GRY), "
- "3 (RGB), or 4 (RGBA)");
+ return errors::InvalidArgument(
+ "Channel depth should be either 1 (GRY), "
+ "3 (RGB), or 4 (RGBA)");
}
}
@@ -709,6 +710,40 @@ REGISTER_OP("NonMaxSuppressionV3")
return Status::OK();
});
+REGISTER_OP("NonMaxSuppressionV4")
+ .Input("boxes: float")
+ .Input("scores: float")
+ .Input("max_output_size: int32")
+ .Input("iou_threshold: float")
+ .Input("score_threshold: float")
+ .Output("selected_indices: int32")
+ .Output("valid_outputs: int32")
+ .Attr("pad_to_max_output_size: bool = false")
+ .SetShapeFn([](InferenceContext* c) {
+ // Get inputs and validate ranks.
+ ShapeHandle boxes;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &boxes));
+ ShapeHandle scores;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &scores));
+ ShapeHandle max_output_size;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &max_output_size));
+ ShapeHandle iou_threshold;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &iou_threshold));
+ ShapeHandle score_threshold;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &score_threshold));
+ // The boxes is a 2-D float Tensor of shape [num_boxes, 4].
+ DimensionHandle unused;
+ // The boxes[0] and scores[0] are both num_boxes.
+ TF_RETURN_IF_ERROR(
+ c->Merge(c->Dim(boxes, 0), c->Dim(scores, 0), &unused));
+ // The boxes[1] is 4.
+ TF_RETURN_IF_ERROR(c->WithValue(c->Dim(boxes, 1), 4, &unused));
+
+ c->set_output(0, c->Vector(c->UnknownDim()));
+ c->set_output(1, c->MakeShape({}));
+ return Status::OK();
+ });
+
REGISTER_OP("NonMaxSuppressionWithOverlaps")
.Input("overlaps: float")
.Input("scores: float")
diff --git a/tensorflow/python/ops/image_ops_impl.py b/tensorflow/python/ops/image_ops_impl.py
index 9440bab9ee..855a4d0c33 100644
--- a/tensorflow/python/ops/image_ops_impl.py
+++ b/tensorflow/python/ops/image_ops_impl.py
@@ -20,6 +20,7 @@ from __future__ import print_function
import numpy as np
+from tensorflow.python.compat import compat
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@@ -2110,6 +2111,64 @@ def non_max_suppression(boxes,
iou_threshold, score_threshold)
+@tf_export('image.non_max_suppression_padded')
+def non_max_suppression_padded(boxes,
+ scores,
+ max_output_size,
+ iou_threshold=0.5,
+ score_threshold=float('-inf'),
+ pad_to_max_output_size=False,
+ name=None):
+ """Greedily selects a subset of bounding boxes in descending order of score.
+
+ Performs algorithmically equivalent operation to tf.image.non_max_suppression,
+ with the addition of an optional parameter which zero-pads the output to
+ be of size `max_output_size`.
+ The output of this operation is a tuple containing the set of integers
+ indexing into the input collection of bounding boxes representing the selected
+ boxes and the number of valid indices in the index set. The bounding box
+ coordinates corresponding to the selected indices can then be obtained using
+ the `tf.slice` and `tf.gather` operations. For example:
+ selected_indices_padded, num_valid = tf.image.non_max_suppression_padded(
+ boxes, scores, max_output_size, iou_threshold,
+ score_threshold, pad_to_max_output_size=True)
+ selected_indices = tf.slice(
+ selected_indices_padded, tf.constant([0]), num_valid)
+ selected_boxes = tf.gather(boxes, selected_indices)
+
+ Args:
+ boxes: A 2-D float `Tensor` of shape `[num_boxes, 4]`.
+ scores: A 1-D float `Tensor` of shape `[num_boxes]` representing a single
+ score corresponding to each box (each row of boxes).
+ max_output_size: A scalar integer `Tensor` representing the maximum number
+ of boxes to be selected by non max suppression.
+ iou_threshold: A float representing the threshold for deciding whether boxes
+ overlap too much with respect to IOU.
+ score_threshold: A float representing the threshold for deciding when to
+ remove boxes based on score.
+ pad_to_max_output_size: bool. If True, size of `selected_indices` output
+ is padded to `max_output_size`.
+ name: A name for the operation (optional).
+
+ Returns:
+ selected_indices: A 1-D integer `Tensor` of shape `[M]` representing the
+ selected indices from the boxes tensor, where `M <= max_output_size`.
+ valid_outputs: A scalar integer `Tensor` denoting how many elements in
+ `selected_indices` are valid. Valid elements occur first, then padding.
+ """
+ with ops.name_scope(name, 'non_max_suppression_padded'):
+ iou_threshold = ops.convert_to_tensor(iou_threshold, name='iou_threshold')
+ score_threshold = ops.convert_to_tensor(
+ score_threshold, name='score_threshold')
+ if compat.forward_compatible(2018, 8, 7) or pad_to_max_output_size:
+ return gen_image_ops.non_max_suppression_v4(
+ boxes, scores, max_output_size, iou_threshold, score_threshold,
+ pad_to_max_output_size)
+ else:
+ return gen_image_ops.non_max_suppression_v3(
+ boxes, scores, max_output_size, iou_threshold, score_threshold)
+
+
@tf_export('image.non_max_suppression_overlaps')
def non_max_suppression_with_overlaps(overlaps,
scores,
diff --git a/tensorflow/tools/api/golden/tensorflow.image.pbtxt b/tensorflow/tools/api/golden/tensorflow.image.pbtxt
index 6ec3aba775..5c46dc5ee7 100644
--- a/tensorflow/tools/api/golden/tensorflow.image.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.image.pbtxt
@@ -125,6 +125,10 @@ tf_module {
argspec: "args=[\'overlaps\', \'scores\', \'max_output_size\', \'overlap_threshold\', \'score_threshold\', \'name\'], varargs=None, keywords=None, defaults=[\'0.5\', \'-inf\', \'None\'], "
}
member_method {
+ name: "non_max_suppression_padded"
+ argspec: "args=[\'boxes\', \'scores\', \'max_output_size\', \'iou_threshold\', \'score_threshold\', \'pad_to_max_output_size\', \'name\'], varargs=None, keywords=None, defaults=[\'0.5\', \'-inf\', \'False\', \'None\'], "
+ }
+ member_method {
name: "pad_to_bounding_box"
argspec: "args=[\'image\', \'offset_height\', \'offset_width\', \'target_height\', \'target_width\'], varargs=None, keywords=None, defaults=None"
}