diff options
author | Tayo Oguntebi <tayo@google.com> | 2018-07-30 17:13:09 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-30 17:16:55 -0700 |
commit | 8566d9e6fa7dbe3660339befe8b0a3344d24ef2b (patch) | |
tree | 54143734f9b19a9ead1a23784f75dcd2009d1452 /tensorflow/core/kernels/non_max_suppression_op.cc | |
parent | 8fee2e4b7c915d952332dc8cc9be7cfefea35162 (diff) |
Adds a NonMaxSuppressionV4 op, with a corresponding TF2XLA implementation.
PiperOrigin-RevId: 206673787
Diffstat (limited to 'tensorflow/core/kernels/non_max_suppression_op.cc')
-rw-r--r-- | tensorflow/core/kernels/non_max_suppression_op.cc | 111 |
1 files changed, 88 insertions, 23 deletions
diff --git a/tensorflow/core/kernels/non_max_suppression_op.cc b/tensorflow/core/kernels/non_max_suppression_op.cc index f59843a07a..c7d0d4de0d 100644 --- a/tensorflow/core/kernels/non_max_suppression_op.cc +++ b/tensorflow/core/kernels/non_max_suppression_op.cc @@ -121,10 +121,11 @@ static inline std::function<bool(int, int)> CreateOverlapsSuppressCheckFn( std::placeholders::_1, std::placeholders::_2, threshold); } -void DoNonMaxSuppressionOp(OpKernelContext* context, const Tensor& scores, - int num_boxes, const Tensor& max_output_size, - const float score_threshold, - std::function<bool(int, int)> suppress_check_fn) { +void DoNonMaxSuppressionOp( + OpKernelContext* context, const Tensor& scores, int num_boxes, + const Tensor& max_output_size, const float score_threshold, + const std::function<bool(int, int)>& suppress_check_fn, + bool pad_to_max_output_size = false, int* ptr_num_valid_outputs = nullptr) { const int output_size = std::min(max_output_size.scalar<int>()(), num_boxes); std::vector<float> scores_data(num_boxes); @@ -172,6 +173,15 @@ void DoNonMaxSuppressionOp(OpKernelContext* context, const Tensor& scores, } } + int num_valid_outputs = selected.size(); + if (pad_to_max_output_size) { + selected.resize(output_size, 0); + selected_scores.resize(output_size, 0); + } + if (ptr_num_valid_outputs) { + *ptr_num_valid_outputs = num_valid_outputs; + } + // Allocate output tensors Tensor* output_indices = nullptr; TensorShape output_shape({static_cast<int>(selected.size())}); @@ -262,54 +272,106 @@ class NonMaxSuppressionV2Op : public OpKernel { } }; -template <typename Device> -class NonMaxSuppressionV3Op : public OpKernel { +class NonMaxSuppressionV3V4Base : public OpKernel { public: - explicit NonMaxSuppressionV3Op(OpKernelConstruction* context) + explicit NonMaxSuppressionV3V4Base(OpKernelConstruction* context) : OpKernel(context) {} void Compute(OpKernelContext* context) override { // boxes: [num_boxes, 4] - const Tensor& boxes = context->input(0); + boxes_ = context->input(0); // scores: [num_boxes] - const Tensor& scores = context->input(1); + scores_ = context->input(1); // max_output_size: scalar - const Tensor& max_output_size = context->input(2); + max_output_size_ = context->input(2); OP_REQUIRES( - context, TensorShapeUtils::IsScalar(max_output_size.shape()), + context, TensorShapeUtils::IsScalar(max_output_size_.shape()), errors::InvalidArgument("max_output_size must be 0-D, got shape ", - max_output_size.shape().DebugString())); + max_output_size_.shape().DebugString())); // iou_threshold: scalar const Tensor& iou_threshold = context->input(3); OP_REQUIRES(context, TensorShapeUtils::IsScalar(iou_threshold.shape()), errors::InvalidArgument("iou_threshold must be 0-D, got shape ", iou_threshold.shape().DebugString())); - const float iou_threshold_val = iou_threshold.scalar<float>()(); - + iou_threshold_val_ = iou_threshold.scalar<float>()(); + OP_REQUIRES(context, iou_threshold_val_ >= 0 && iou_threshold_val_ <= 1, + errors::InvalidArgument("iou_threshold must be in [0, 1]")); // score_threshold: scalar const Tensor& score_threshold = context->input(4); OP_REQUIRES( context, TensorShapeUtils::IsScalar(score_threshold.shape()), errors::InvalidArgument("score_threshold must be 0-D, got shape ", score_threshold.shape().DebugString())); - const float score_threshold_val = score_threshold.scalar<float>()(); + score_threshold_val_ = score_threshold.scalar<float>()(); - OP_REQUIRES(context, iou_threshold_val >= 0 && iou_threshold_val <= 1, - errors::InvalidArgument("iou_threshold must be in [0, 1]")); - int num_boxes = 0; - ParseAndCheckBoxSizes(context, boxes, &num_boxes); - CheckScoreSizes(context, num_boxes, scores); + num_boxes_ = 0; + ParseAndCheckBoxSizes(context, boxes_, &num_boxes_); + CheckScoreSizes(context, num_boxes_, scores_); if (!context->status().ok()) { return; } - auto suppress_check_fn = CreateIOUSuppressCheckFn(boxes, iou_threshold_val); - DoNonMaxSuppressionOp(context, scores, num_boxes, max_output_size, - score_threshold_val, suppress_check_fn); + DoComputeAndPostProcess(context); + } + + protected: + virtual void DoComputeAndPostProcess(OpKernelContext* context) = 0; + + Tensor boxes_; + Tensor scores_; + Tensor max_output_size_; + int num_boxes_; + float iou_threshold_val_; + float score_threshold_val_; +}; + +template <typename Device> +class NonMaxSuppressionV3Op : public NonMaxSuppressionV3V4Base { + public: + explicit NonMaxSuppressionV3Op(OpKernelConstruction* context) + : NonMaxSuppressionV3V4Base(context) {} + + protected: + void DoComputeAndPostProcess(OpKernelContext* context) override { + auto suppress_check_fn = + CreateIOUSuppressCheckFn(boxes_, iou_threshold_val_); + + DoNonMaxSuppressionOp(context, scores_, num_boxes_, max_output_size_, + score_threshold_val_, suppress_check_fn); } }; template <typename Device> +class NonMaxSuppressionV4Op : public NonMaxSuppressionV3V4Base { + public: + explicit NonMaxSuppressionV4Op(OpKernelConstruction* context) + : NonMaxSuppressionV3V4Base(context) { + OP_REQUIRES_OK(context, context->GetAttr("pad_to_max_output_size", + &pad_to_max_output_size_)); + } + + protected: + void DoComputeAndPostProcess(OpKernelContext* context) override { + auto suppress_check_fn = + CreateIOUSuppressCheckFn(boxes_, iou_threshold_val_); + int num_valid_outputs; + + DoNonMaxSuppressionOp(context, scores_, num_boxes_, max_output_size_, + score_threshold_val_, suppress_check_fn, + pad_to_max_output_size_, &num_valid_outputs); + + // Allocate scalar output tensor for number of indices computed. + Tensor* num_outputs_t = nullptr; + OP_REQUIRES_OK(context, context->allocate_output( + 1, tensorflow::TensorShape{}, &num_outputs_t)); + num_outputs_t->scalar<int32>().setConstant(num_valid_outputs); + } + + private: + bool pad_to_max_output_size_; +}; + +template <typename Device> class NonMaxSuppressionWithOverlapsOp : public OpKernel { public: explicit NonMaxSuppressionWithOverlapsOp(OpKernelConstruction* context) @@ -365,6 +427,9 @@ REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV2").Device(DEVICE_CPU), REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV3").Device(DEVICE_CPU), NonMaxSuppressionV3Op<CPUDevice>); +REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV4").Device(DEVICE_CPU), + NonMaxSuppressionV4Op<CPUDevice>); + REGISTER_KERNEL_BUILDER( Name("NonMaxSuppressionWithOverlaps").Device(DEVICE_CPU), NonMaxSuppressionWithOverlapsOp<CPUDevice>); |