diff options
Diffstat (limited to 'tensorflow/core/kernels/non_max_suppression_op.cc')
-rw-r--r-- | tensorflow/core/kernels/non_max_suppression_op.cc | 36 |
1 files changed, 16 insertions, 20 deletions
diff --git a/tensorflow/core/kernels/non_max_suppression_op.cc b/tensorflow/core/kernels/non_max_suppression_op.cc index dc95f67ff0..9ffe71e031 100644 --- a/tensorflow/core/kernels/non_max_suppression_op.cc +++ b/tensorflow/core/kernels/non_max_suppression_op.cc @@ -90,24 +90,20 @@ static inline float ComputeIOU(typename TTypes<float, 2>::ConstTensor boxes, return intersection_area / (area_i + area_j - intersection_area); } -void DoNonMaxSuppressionOp(OpKernelContext* context, - const Tensor& boxes, - const Tensor& scores, - const Tensor& max_output_size, +void DoNonMaxSuppressionOp(OpKernelContext* context, const Tensor& boxes, + const Tensor& scores, const Tensor& max_output_size, const float iou_threshold) { OP_REQUIRES(context, iou_threshold >= 0 && iou_threshold <= 1, - errors::InvalidArgument("iou_threshold must be in [0, 1]")); - + errors::InvalidArgument("iou_threshold must be in [0, 1]")); + int num_boxes = 0; ParseAndCheckBoxSizes(context, boxes, scores, &num_boxes); if (!context->status().ok()) { return; } - const int output_size = - std::min(max_output_size.scalar<int>()(), num_boxes); - typename TTypes<float, 2>::ConstTensor boxes_data = - boxes.tensor<float, 2>(); + const int output_size = std::min(max_output_size.scalar<int>()(), num_boxes); + typename TTypes<float, 2>::ConstTensor boxes_data = boxes.tensor<float, 2>(); std::vector<float> scores_data(num_boxes); std::copy_n(scores.flat<float>().data(), num_boxes, scores_data.begin()); @@ -127,7 +123,7 @@ void DoNonMaxSuppressionOp(OpKernelContext* context, for (int j = i + 1; j < num_boxes; ++j) { if (active[j]) { float iou = - ComputeIOU(boxes_data, sorted_indices[i], sorted_indices[j]); + ComputeIOU(boxes_data, sorted_indices[i], sorted_indices[j]); if (iou > iou_threshold) { active[j] = false; num_active--; @@ -145,7 +141,7 @@ void DoNonMaxSuppressionOp(OpKernelContext* context, std::copy_n(selected.begin(), selected.size(), selected_indices_data.data()); } -} // namespace +} // namespace template <typename Device> class NonMaxSuppressionOp : public OpKernel { @@ -167,7 +163,8 @@ class NonMaxSuppressionOp : public OpKernel { errors::InvalidArgument("max_output_size must be 0-D, got shape ", max_output_size.shape().DebugString())); - DoNonMaxSuppressionOp(context, boxes, scores, max_output_size, iou_threshold_); + DoNonMaxSuppressionOp(context, boxes, scores, max_output_size, + iou_threshold_); } private: @@ -178,8 +175,7 @@ template <typename Device> class NonMaxSuppressionV2Op : public OpKernel { public: explicit NonMaxSuppressionV2Op(OpKernelConstruction* context) - : OpKernel(context) { - } + : OpKernel(context) {} void Compute(OpKernelContext* context) override { // boxes: [num_boxes, 4] @@ -194,14 +190,14 @@ class NonMaxSuppressionV2Op : public OpKernel { max_output_size.shape().DebugString())); // iou_threshold: scalar const Tensor& iou_threshold = context->input(3); - OP_REQUIRES( - context, TensorShapeUtils::IsScalar(iou_threshold.shape()), - errors::InvalidArgument("iou_threshold must be 0-D, got shape ", - iou_threshold.shape().DebugString())); + OP_REQUIRES(context, TensorShapeUtils::IsScalar(iou_threshold.shape()), + errors::InvalidArgument("iou_threshold must be 0-D, got shape ", + iou_threshold.shape().DebugString())); const float iou_threshold_val = iou_threshold.scalar<float>()(); - DoNonMaxSuppressionOp(context, boxes, scores, max_output_size, iou_threshold_val); + DoNonMaxSuppressionOp(context, boxes, scores, max_output_size, + iou_threshold_val); } }; |