about summary refs log tree commit diff homepage
path: root/tensorflow/core/kernels/non_max_suppression_op.cc
diff options
context:
space:
mode:
author: Tayo Oguntebi <tayo@google.com>, 2018-07-30 17:13:09 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org>, 2018-07-30 17:16:55 -0700
commit: 8566d9e6fa7dbe3660339befe8b0a3344d24ef2b (patch)
tree: 54143734f9b19a9ead1a23784f75dcd2009d1452 /tensorflow/core/kernels/non_max_suppression_op.cc
parent: 8fee2e4b7c915d952332dc8cc9be7cfefea35162 (diff)
Adds a NonMaxSuppressionV4 op, with a corresponding TF2XLA implementation.
PiperOrigin-RevId: 206673787
Diffstat (limited to 'tensorflow/core/kernels/non_max_suppression_op.cc')
-rw-r--r-- tensorflow/core/kernels/non_max_suppression_op.cc 111
1 files changed, 88 insertions, 23 deletions
diff --git a/tensorflow/core/kernels/non_max_suppression_op.cc b/tensorflow/core/kernels/non_max_suppression_op.cc
index f59843a07a..c7d0d4de0d 100644
--- a/tensorflow/core/kernels/non_max_suppression_op.cc
+++ b/tensorflow/core/kernels/non_max_suppression_op.cc
@@ -121,10 +121,11 @@ static inline std::function<bool(int, int)> CreateOverlapsSuppressCheckFn(
std::placeholders::_1, std::placeholders::_2, threshold);
}
-void DoNonMaxSuppressionOp(OpKernelContext* context, const Tensor& scores,
- int num_boxes, const Tensor& max_output_size,
- const float score_threshold,
- std::function<bool(int, int)> suppress_check_fn) {
+void DoNonMaxSuppressionOp(
+ OpKernelContext* context, const Tensor& scores, int num_boxes,
+ const Tensor& max_output_size, const float score_threshold,
+ const std::function<bool(int, int)>& suppress_check_fn,
+ bool pad_to_max_output_size = false, int* ptr_num_valid_outputs = nullptr) {
const int output_size = std::min(max_output_size.scalar<int>()(), num_boxes);
std::vector<float> scores_data(num_boxes);
@@ -172,6 +173,15 @@ void DoNonMaxSuppressionOp(OpKernelContext* context, const Tensor& scores,
}
}
+ int num_valid_outputs = selected.size();
+ if (pad_to_max_output_size) {
+ selected.resize(output_size, 0);
+ selected_scores.resize(output_size, 0);
+ }
+ if (ptr_num_valid_outputs) {
+ *ptr_num_valid_outputs = num_valid_outputs;
+ }
+
// Allocate output tensors
Tensor* output_indices = nullptr;
TensorShape output_shape({static_cast<int>(selected.size())});
@@ -262,54 +272,106 @@ class NonMaxSuppressionV2Op : public OpKernel {
}
};
-template <typename Device>
-class NonMaxSuppressionV3Op : public OpKernel {
+class NonMaxSuppressionV3V4Base : public OpKernel {
public:
- explicit NonMaxSuppressionV3Op(OpKernelConstruction* context)
+ explicit NonMaxSuppressionV3V4Base(OpKernelConstruction* context)
: OpKernel(context) {}
void Compute(OpKernelContext* context) override {
// boxes: [num_boxes, 4]
- const Tensor& boxes = context->input(0);
+ boxes_ = context->input(0);
// scores: [num_boxes]
- const Tensor& scores = context->input(1);
+ scores_ = context->input(1);
// max_output_size: scalar
- const Tensor& max_output_size = context->input(2);
+ max_output_size_ = context->input(2);
OP_REQUIRES(
- context, TensorShapeUtils::IsScalar(max_output_size.shape()),
+ context, TensorShapeUtils::IsScalar(max_output_size_.shape()),
errors::InvalidArgument("max_output_size must be 0-D, got shape ",
- max_output_size.shape().DebugString()));
+ max_output_size_.shape().DebugString()));
// iou_threshold: scalar
const Tensor& iou_threshold = context->input(3);
OP_REQUIRES(context, TensorShapeUtils::IsScalar(iou_threshold.shape()),
errors::InvalidArgument("iou_threshold must be 0-D, got shape ",
iou_threshold.shape().DebugString()));
- const float iou_threshold_val = iou_threshold.scalar<float>()();
-
+ iou_threshold_val_ = iou_threshold.scalar<float>()();
+ OP_REQUIRES(context, iou_threshold_val_ >= 0 && iou_threshold_val_ <= 1,
+ errors::InvalidArgument("iou_threshold must be in [0, 1]"));
// score_threshold: scalar
const Tensor& score_threshold = context->input(4);
OP_REQUIRES(
context, TensorShapeUtils::IsScalar(score_threshold.shape()),
errors::InvalidArgument("score_threshold must be 0-D, got shape ",
score_threshold.shape().DebugString()));
- const float score_threshold_val = score_threshold.scalar<float>()();
+ score_threshold_val_ = score_threshold.scalar<float>()();
- OP_REQUIRES(context, iou_threshold_val >= 0 && iou_threshold_val <= 1,
- errors::InvalidArgument("iou_threshold must be in [0, 1]"));
- int num_boxes = 0;
- ParseAndCheckBoxSizes(context, boxes, &num_boxes);
- CheckScoreSizes(context, num_boxes, scores);
+ num_boxes_ = 0;
+ ParseAndCheckBoxSizes(context, boxes_, &num_boxes_);
+ CheckScoreSizes(context, num_boxes_, scores_);
if (!context->status().ok()) {
return;
}
- auto suppress_check_fn = CreateIOUSuppressCheckFn(boxes, iou_threshold_val);
- DoNonMaxSuppressionOp(context, scores, num_boxes, max_output_size,
- score_threshold_val, suppress_check_fn);
+ DoComputeAndPostProcess(context);
+ }
+
+ protected:
+ virtual void DoComputeAndPostProcess(OpKernelContext* context) = 0;
+
+ Tensor boxes_;
+ Tensor scores_;
+ Tensor max_output_size_;
+ int num_boxes_;
+ float iou_threshold_val_;
+ float score_threshold_val_;
+};
+
+template <typename Device>
+class NonMaxSuppressionV3Op : public NonMaxSuppressionV3V4Base {
+ public:
+ explicit NonMaxSuppressionV3Op(OpKernelConstruction* context)
+ : NonMaxSuppressionV3V4Base(context) {}
+
+ protected:
+ void DoComputeAndPostProcess(OpKernelContext* context) override {
+ auto suppress_check_fn =
+ CreateIOUSuppressCheckFn(boxes_, iou_threshold_val_);
+
+ DoNonMaxSuppressionOp(context, scores_, num_boxes_, max_output_size_,
+ score_threshold_val_, suppress_check_fn);
}
};
template <typename Device>
+class NonMaxSuppressionV4Op : public NonMaxSuppressionV3V4Base {
+ public:
+ explicit NonMaxSuppressionV4Op(OpKernelConstruction* context)
+ : NonMaxSuppressionV3V4Base(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("pad_to_max_output_size",
+ &pad_to_max_output_size_));
+ }
+
+ protected:
+ void DoComputeAndPostProcess(OpKernelContext* context) override {
+ auto suppress_check_fn =
+ CreateIOUSuppressCheckFn(boxes_, iou_threshold_val_);
+ int num_valid_outputs;
+
+ DoNonMaxSuppressionOp(context, scores_, num_boxes_, max_output_size_,
+ score_threshold_val_, suppress_check_fn,
+ pad_to_max_output_size_, &num_valid_outputs);
+
+ // Allocate scalar output tensor for number of indices computed.
+ Tensor* num_outputs_t = nullptr;
+ OP_REQUIRES_OK(context, context->allocate_output(
+ 1, tensorflow::TensorShape{}, &num_outputs_t));
+ num_outputs_t->scalar<int32>().setConstant(num_valid_outputs);
+ }
+
+ private:
+ bool pad_to_max_output_size_;
+};
+
+template <typename Device>
class NonMaxSuppressionWithOverlapsOp : public OpKernel {
public:
explicit NonMaxSuppressionWithOverlapsOp(OpKernelConstruction* context)
@@ -365,6 +427,9 @@ REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV2").Device(DEVICE_CPU),
REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV3").Device(DEVICE_CPU),
NonMaxSuppressionV3Op<CPUDevice>);
+REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV4").Device(DEVICE_CPU),
+ NonMaxSuppressionV4Op<CPUDevice>);
+
REGISTER_KERNEL_BUILDER(
Name("NonMaxSuppressionWithOverlaps").Device(DEVICE_CPU),
NonMaxSuppressionWithOverlapsOp<CPUDevice>);