aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/non_max_suppression_op.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-05-18 18:02:32 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-05-18 18:06:21 -0700
commit53cb26d05a5c2080d8022124178b1cc43a30ffe5 (patch)
treeba11f5e078e8300e0a88f96f1029c549ade2a6c0 /tensorflow/core/kernels/non_max_suppression_op.cc
parentc311af00f2d72940c75ab0fc125ba2949858b2a9 (diff)
Merge changes from github.
END_PUBLIC --- Commit c2b8927f2 authored by Dandelion Mané<dandelion@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Fix another d3v4 regression in the graph visualizer. PiperOrigin-RevId: 156343038 --- Commit 170f0b350 authored by Peter Hawkins<phawkins@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: [TF:XLA] Add XLA implementation of ResourceStridedSliceAssign. PiperOrigin-RevId: 156341053 --- Commit 1390dd68f authored by Vijay Vasudevan<vrv@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: When Op Type is not registered, log the hostname of the machine that it is running on in the error message, since the message could be routed back during a failure on a remote binary, and it is hard to tell which machine it came from. Ideally, we'd somehow log the name of the binary running instead, but we don't have a function to get that right now. PiperOrigin-RevId: 156337679 --- Commit 9ca8a151b authored by A. Unique TensorFlower<gardener@tensorflow.org> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Internal change. PiperOrigin-RevId: 156335942 --- Commit 40255434c authored by Martin Wicke<wicke@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Deprecate contrib/learn/dataframe. To be removed June 15. PiperOrigin-RevId: 156333930 --- Commit 7f71b7fbe authored by A. Unique TensorFlower<gardener@tensorflow.org> Committed by TensorFlower Gardener<gardener@tensorflow.org>: BEGIN_PUBLIC Automated g4 rollback of changelist 156123287 PiperOrigin-RevId: 156503903
Diffstat (limited to 'tensorflow/core/kernels/non_max_suppression_op.cc')
-rw-r--r--tensorflow/core/kernels/non_max_suppression_op.cc138
1 file changed, 89 insertions, 49 deletions
diff --git a/tensorflow/core/kernels/non_max_suppression_op.cc b/tensorflow/core/kernels/non_max_suppression_op.cc
index 4d4851c70c..9ffe71e031 100644
--- a/tensorflow/core/kernels/non_max_suppression_op.cc
+++ b/tensorflow/core/kernels/non_max_suppression_op.cc
@@ -33,6 +33,7 @@ limitations under the License.
#include "tensorflow/core/platform/logging.h"
namespace tensorflow {
+namespace {
typedef Eigen::ThreadPoolDevice CPUDevice;
@@ -89,6 +90,59 @@ static inline float ComputeIOU(typename TTypes<float, 2>::ConstTensor boxes,
return intersection_area / (area_i + area_j - intersection_area);
}
+// Shared implementation for the NonMaxSuppression kernels. Greedily selects
+// up to max_output_size boxes in order of decreasing score, suppressing any
+// later box whose IOU with an already-selected box exceeds iou_threshold.
+// On success, allocates output 0 as a 1-D int32 tensor holding the selected
+// indices (into the original boxes/scores tensors). On invalid input, sets an
+// error status on `context` and returns without allocating an output.
+void DoNonMaxSuppressionOp(OpKernelContext* context, const Tensor& boxes,
+                           const Tensor& scores, const Tensor& max_output_size,
+                           const float iou_threshold) {
+  OP_REQUIRES(context, iou_threshold >= 0 && iou_threshold <= 1,
+              errors::InvalidArgument("iou_threshold must be in [0, 1]"));
+
+  // Validates boxes/scores shapes and extracts the box count; bails out if
+  // the helper recorded an error on the context.
+  int num_boxes = 0;
+  ParseAndCheckBoxSizes(context, boxes, scores, &num_boxes);
+  if (!context->status().ok()) {
+    return;
+  }
+
+  // Never select more boxes than exist.
+  const int output_size = std::min(max_output_size.scalar<int>()(), num_boxes);
+  typename TTypes<float, 2>::ConstTensor boxes_data = boxes.tensor<float, 2>();
+
+  // Copy scores into a vector and argsort them (decreasing order, per the
+  // helper's name — the helper body is outside this diff).
+  std::vector<float> scores_data(num_boxes);
+  std::copy_n(scores.flat<float>().data(), num_boxes, scores_data.begin());
+  std::vector<int> sorted_indices;
+  DecreasingArgSort(scores_data, &sorted_indices);
+
+  // Greedy selection over score-sorted boxes. active[i] tracks whether the
+  // i-th highest-scoring box is still a candidate; num_active counts the
+  // remaining candidates so we can stop early once all are suppressed.
+  std::vector<bool> active(num_boxes, true);
+  std::vector<int> selected;
+  int num_active = active.size();
+  for (int i = 0; i < num_boxes; ++i) {
+    // NOTE(review): selected.size() is size_t while output_size is int —
+    // signed/unsigned comparison. Safe here because output_size >= 0, but a
+    // static_cast would silence compiler warnings.
+    if (num_active == 0 || selected.size() >= output_size) break;
+    if (active[i]) {
+      selected.push_back(sorted_indices[i]);
+    } else {
+      continue;
+    }
+    // Suppress every lower-scoring candidate that overlaps the box just
+    // selected by more than the threshold.
+    for (int j = i + 1; j < num_boxes; ++j) {
+      if (active[j]) {
+        float iou =
+            ComputeIOU(boxes_data, sorted_indices[i], sorted_indices[j]);
+        if (iou > iou_threshold) {
+          active[j] = false;
+          num_active--;
+        }
+      }
+    }
+  }
+
+  // Allocate output tensor
+  Tensor* output = nullptr;
+  TensorShape output_shape({static_cast<int>(selected.size())});
+  OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
+  typename TTypes<int, 1>::Tensor selected_indices_data =
+      output->tensor<int, 1>();
+  std::copy_n(selected.begin(), selected.size(), selected_indices_data.data());
+}
+
+} // namespace
+
template <typename Device>
class NonMaxSuppressionOp : public OpKernel {
public:
@@ -98,9 +152,6 @@ class NonMaxSuppressionOp : public OpKernel {
}
void Compute(OpKernelContext* context) override {
- OP_REQUIRES(context, iou_threshold_ >= 0 && iou_threshold_ <= 1,
- errors::InvalidArgument("iou_threshold must be in [0, 1]"));
-
// boxes: [num_boxes, 4]
const Tensor& boxes = context->input(0);
// scores: [num_boxes]
@@ -112,59 +163,48 @@ class NonMaxSuppressionOp : public OpKernel {
errors::InvalidArgument("max_output_size must be 0-D, got shape ",
max_output_size.shape().DebugString()));
- int num_boxes = 0;
- ParseAndCheckBoxSizes(context, boxes, scores, &num_boxes);
- if (!context->status().ok()) {
- return;
- }
-
- const int output_size =
- std::min(max_output_size.scalar<int>()(), num_boxes);
- typename TTypes<float, 2>::ConstTensor boxes_data =
- boxes.tensor<float, 2>();
-
- std::vector<float> scores_data(num_boxes);
- std::copy_n(scores.flat<float>().data(), num_boxes, scores_data.begin());
- std::vector<int> sorted_indices;
- DecreasingArgSort(scores_data, &sorted_indices);
-
- std::vector<bool> active(num_boxes, true);
- std::vector<int> selected;
- int num_active = active.size();
- for (int i = 0; i < num_boxes; ++i) {
- if (num_active == 0 || selected.size() >= output_size) break;
- if (active[i]) {
- selected.push_back(sorted_indices[i]);
- } else {
- continue;
- }
- for (int j = i + 1; j < num_boxes; ++j) {
- if (active[j]) {
- float iou =
- ComputeIOU(boxes_data, sorted_indices[i], sorted_indices[j]);
- if (iou > iou_threshold_) {
- active[j] = false;
- num_active--;
- }
- }
- }
- }
-
- // Allocate output tensor
- Tensor* output = nullptr;
- TensorShape output_shape({static_cast<int>(selected.size())});
- OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
- typename TTypes<int, 1>::Tensor selected_indices_data =
- output->tensor<int, 1>();
- std::copy_n(selected.begin(), selected.size(),
- selected_indices_data.data());
+ DoNonMaxSuppressionOp(context, boxes, scores, max_output_size,
+ iou_threshold_);
}
private:
float iou_threshold_;
};
+// NonMaxSuppressionV2 kernel: identical to NonMaxSuppression except that
+// iou_threshold arrives as a scalar input tensor (input 3) instead of an op
+// attribute, so it can vary between executions.
+template <typename Device>
+class NonMaxSuppressionV2Op : public OpKernel {
+ public:
+  explicit NonMaxSuppressionV2Op(OpKernelConstruction* context)
+      : OpKernel(context) {}
+
+  void Compute(OpKernelContext* context) override {
+    // boxes: [num_boxes, 4]
+    const Tensor& boxes = context->input(0);
+    // scores: [num_boxes]
+    const Tensor& scores = context->input(1);
+    // max_output_size: scalar
+    const Tensor& max_output_size = context->input(2);
+    OP_REQUIRES(
+        context, TensorShapeUtils::IsScalar(max_output_size.shape()),
+        errors::InvalidArgument("max_output_size must be 0-D, got shape ",
+                                max_output_size.shape().DebugString()));
+    // iou_threshold: scalar
+    const Tensor& iou_threshold = context->input(3);
+    OP_REQUIRES(context, TensorShapeUtils::IsScalar(iou_threshold.shape()),
+                errors::InvalidArgument("iou_threshold must be 0-D, got shape ",
+                                        iou_threshold.shape().DebugString()));
+
+    const float iou_threshold_val = iou_threshold.scalar<float>()();
+
+    // Range validation of iou_threshold_val ([0, 1]) is performed inside the
+    // shared helper, which also does all shape checks on boxes/scores.
+    DoNonMaxSuppressionOp(context, boxes, scores, max_output_size,
+                          iou_threshold_val);
+  }
+};
+
REGISTER_KERNEL_BUILDER(Name("NonMaxSuppression").Device(DEVICE_CPU),
NonMaxSuppressionOp<CPUDevice>);
+// Register the V2 kernel, which takes iou_threshold as an input tensor
+// rather than as an op attribute.
+REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV2").Device(DEVICE_CPU),
+                        NonMaxSuppressionV2Op<CPUDevice>);
+
} // namespace tensorflow