aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/non_max_suppression_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/non_max_suppression_op.cc')
-rw-r--r--tensorflow/core/kernels/non_max_suppression_op.cc36
1 files changed, 16 insertions, 20 deletions
diff --git a/tensorflow/core/kernels/non_max_suppression_op.cc b/tensorflow/core/kernels/non_max_suppression_op.cc
index dc95f67ff0..9ffe71e031 100644
--- a/tensorflow/core/kernels/non_max_suppression_op.cc
+++ b/tensorflow/core/kernels/non_max_suppression_op.cc
@@ -90,24 +90,20 @@ static inline float ComputeIOU(typename TTypes<float, 2>::ConstTensor boxes,
return intersection_area / (area_i + area_j - intersection_area);
}
-void DoNonMaxSuppressionOp(OpKernelContext* context,
- const Tensor& boxes,
- const Tensor& scores,
- const Tensor& max_output_size,
+void DoNonMaxSuppressionOp(OpKernelContext* context, const Tensor& boxes,
+ const Tensor& scores, const Tensor& max_output_size,
const float iou_threshold) {
OP_REQUIRES(context, iou_threshold >= 0 && iou_threshold <= 1,
- errors::InvalidArgument("iou_threshold must be in [0, 1]"));
-
+ errors::InvalidArgument("iou_threshold must be in [0, 1]"));
+
int num_boxes = 0;
ParseAndCheckBoxSizes(context, boxes, scores, &num_boxes);
if (!context->status().ok()) {
return;
}
- const int output_size =
- std::min(max_output_size.scalar<int>()(), num_boxes);
- typename TTypes<float, 2>::ConstTensor boxes_data =
- boxes.tensor<float, 2>();
+ const int output_size = std::min(max_output_size.scalar<int>()(), num_boxes);
+ typename TTypes<float, 2>::ConstTensor boxes_data = boxes.tensor<float, 2>();
std::vector<float> scores_data(num_boxes);
std::copy_n(scores.flat<float>().data(), num_boxes, scores_data.begin());
@@ -127,7 +123,7 @@ void DoNonMaxSuppressionOp(OpKernelContext* context,
for (int j = i + 1; j < num_boxes; ++j) {
if (active[j]) {
float iou =
- ComputeIOU(boxes_data, sorted_indices[i], sorted_indices[j]);
+ ComputeIOU(boxes_data, sorted_indices[i], sorted_indices[j]);
if (iou > iou_threshold) {
active[j] = false;
num_active--;
@@ -145,7 +141,7 @@ void DoNonMaxSuppressionOp(OpKernelContext* context,
std::copy_n(selected.begin(), selected.size(), selected_indices_data.data());
}
-} // namespace
+} // namespace
template <typename Device>
class NonMaxSuppressionOp : public OpKernel {
@@ -167,7 +163,8 @@ class NonMaxSuppressionOp : public OpKernel {
errors::InvalidArgument("max_output_size must be 0-D, got shape ",
max_output_size.shape().DebugString()));
- DoNonMaxSuppressionOp(context, boxes, scores, max_output_size, iou_threshold_);
+ DoNonMaxSuppressionOp(context, boxes, scores, max_output_size,
+ iou_threshold_);
}
private:
@@ -178,8 +175,7 @@ template <typename Device>
class NonMaxSuppressionV2Op : public OpKernel {
public:
explicit NonMaxSuppressionV2Op(OpKernelConstruction* context)
- : OpKernel(context) {
- }
+ : OpKernel(context) {}
void Compute(OpKernelContext* context) override {
// boxes: [num_boxes, 4]
@@ -194,14 +190,14 @@ class NonMaxSuppressionV2Op : public OpKernel {
max_output_size.shape().DebugString()));
// iou_threshold: scalar
const Tensor& iou_threshold = context->input(3);
- OP_REQUIRES(
- context, TensorShapeUtils::IsScalar(iou_threshold.shape()),
- errors::InvalidArgument("iou_threshold must be 0-D, got shape ",
- iou_threshold.shape().DebugString()));
+ OP_REQUIRES(context, TensorShapeUtils::IsScalar(iou_threshold.shape()),
+ errors::InvalidArgument("iou_threshold must be 0-D, got shape ",
+ iou_threshold.shape().DebugString()));
const float iou_threshold_val = iou_threshold.scalar<float>()();
- DoNonMaxSuppressionOp(context, boxes, scores, max_output_size, iou_threshold_val);
+ DoNonMaxSuppressionOp(context, boxes, scores, max_output_size,
+ iou_threshold_val);
}
};