diff options
Diffstat (limited to 'tensorflow/core/ops/image_ops.cc')
-rw-r--r-- | tensorflow/core/ops/image_ops.cc | 77 |
1 files changed, 34 insertions, 43 deletions
diff --git a/tensorflow/core/ops/image_ops.cc b/tensorflow/core/ops/image_ops.cc index 81f324a3ef..11ca0bd259 100644 --- a/tensorflow/core/ops/image_ops.cc +++ b/tensorflow/core/ops/image_ops.cc @@ -108,6 +108,29 @@ Status ColorspaceShapeFn(InferenceContext* c) { return Status::OK(); } +Status NMSShapeFn(InferenceContext* c) { + // Get inputs and validate ranks. + ShapeHandle boxes; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &boxes)); + ShapeHandle scores; + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &scores)); + ShapeHandle max_output_size; + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &max_output_size)); + ShapeHandle iou_threshold; + TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &iou_threshold)); + ShapeHandle score_threshold; + TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &score_threshold)); + // The boxes is a 2-D float Tensor of shape [num_boxes, 4]. + DimensionHandle unused; + // The boxes[0] and scores[0] are both num_boxes. + TF_RETURN_IF_ERROR(c->Merge(c->Dim(boxes, 0), c->Dim(scores, 0), &unused)); + // The boxes[1] is 4. + TF_RETURN_IF_ERROR(c->WithValue(c->Dim(boxes, 1), 4, &unused)); + + c->set_output(0, c->Vector(c->UnknownDim())); + return Status::OK(); +} + } // namespace // -------------------------------------------------------------------------- @@ -694,29 +717,7 @@ REGISTER_OP("NonMaxSuppressionV3") .Input("iou_threshold: float") .Input("score_threshold: float") .Output("selected_indices: int32") - .SetShapeFn([](InferenceContext* c) { - // Get inputs and validate ranks. - ShapeHandle boxes; - TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &boxes)); - ShapeHandle scores; - TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &scores)); - ShapeHandle max_output_size; - TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &max_output_size)); - ShapeHandle iou_threshold; - TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &iou_threshold)); - ShapeHandle score_threshold; - TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &score_threshold)); - // The boxes is a 2-D float Tensor of shape [num_boxes, 4]. - DimensionHandle unused; - // The boxes[0] and scores[0] are both num_boxes. - TF_RETURN_IF_ERROR( - c->Merge(c->Dim(boxes, 0), c->Dim(scores, 0), &unused)); - // The boxes[1] is 4. - TF_RETURN_IF_ERROR(c->WithValue(c->Dim(boxes, 1), 4, &unused)); - - c->set_output(0, c->Vector(c->UnknownDim())); - return Status::OK(); - }); + .SetShapeFn(NMSShapeFn); REGISTER_OP("NonMaxSuppressionV4") .Input("boxes: float") @@ -728,26 +729,16 @@ REGISTER_OP("NonMaxSuppressionV4") .Output("valid_outputs: int32") .Attr("pad_to_max_output_size: bool = false") .SetShapeFn([](InferenceContext* c) { - // Get inputs and validate ranks. - ShapeHandle boxes; - TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &boxes)); - ShapeHandle scores; - TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &scores)); - ShapeHandle max_output_size; - TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &max_output_size)); - ShapeHandle iou_threshold; - TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &iou_threshold)); - ShapeHandle score_threshold; - TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &score_threshold)); - // The boxes is a 2-D float Tensor of shape [num_boxes, 4]. - DimensionHandle unused; - // The boxes[0] and scores[0] are both num_boxes. - TF_RETURN_IF_ERROR( - c->Merge(c->Dim(boxes, 0), c->Dim(scores, 0), &unused)); - // The boxes[1] is 4. - TF_RETURN_IF_ERROR(c->WithValue(c->Dim(boxes, 1), 4, &unused)); - - c->set_output(0, c->Vector(c->UnknownDim())); + TF_RETURN_IF_ERROR(NMSShapeFn(c)); + + bool pad_to_max; + TF_RETURN_IF_ERROR(c->GetAttr("pad_to_max_output_size", &pad_to_max)); + if (pad_to_max) { + // If padded, overwrite the shape of the output to be static. + DimensionHandle output_dim; + TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(2, &output_dim)); + c->set_output(0, c->MakeShape({output_dim})); + } c->set_output(1, c->MakeShape({})); return Status::OK(); }); |