aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/ops/image_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/ops/image_ops.cc')
-rw-r--r--tensorflow/core/ops/image_ops.cc77
1 files changed, 34 insertions, 43 deletions
diff --git a/tensorflow/core/ops/image_ops.cc b/tensorflow/core/ops/image_ops.cc
index 81f324a3ef..11ca0bd259 100644
--- a/tensorflow/core/ops/image_ops.cc
+++ b/tensorflow/core/ops/image_ops.cc
@@ -108,6 +108,29 @@ Status ColorspaceShapeFn(InferenceContext* c) {
return Status::OK();
}
+Status NMSShapeFn(InferenceContext* c) {
+ // Get inputs and validate ranks.
+ ShapeHandle boxes;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &boxes));
+ ShapeHandle scores;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &scores));
+ ShapeHandle max_output_size;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &max_output_size));
+ ShapeHandle iou_threshold;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &iou_threshold));
+ ShapeHandle score_threshold;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &score_threshold));
+ // The boxes is a 2-D float Tensor of shape [num_boxes, 4].
+ DimensionHandle unused;
+ // The boxes[0] and scores[0] are both num_boxes.
+ TF_RETURN_IF_ERROR(c->Merge(c->Dim(boxes, 0), c->Dim(scores, 0), &unused));
+ // The boxes[1] is 4.
+ TF_RETURN_IF_ERROR(c->WithValue(c->Dim(boxes, 1), 4, &unused));
+
+ c->set_output(0, c->Vector(c->UnknownDim()));
+ return Status::OK();
+}
+
} // namespace
// --------------------------------------------------------------------------
@@ -694,29 +717,7 @@ REGISTER_OP("NonMaxSuppressionV3")
.Input("iou_threshold: float")
.Input("score_threshold: float")
.Output("selected_indices: int32")
- .SetShapeFn([](InferenceContext* c) {
- // Get inputs and validate ranks.
- ShapeHandle boxes;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &boxes));
- ShapeHandle scores;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &scores));
- ShapeHandle max_output_size;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &max_output_size));
- ShapeHandle iou_threshold;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &iou_threshold));
- ShapeHandle score_threshold;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &score_threshold));
- // The boxes is a 2-D float Tensor of shape [num_boxes, 4].
- DimensionHandle unused;
- // The boxes[0] and scores[0] are both num_boxes.
- TF_RETURN_IF_ERROR(
- c->Merge(c->Dim(boxes, 0), c->Dim(scores, 0), &unused));
- // The boxes[1] is 4.
- TF_RETURN_IF_ERROR(c->WithValue(c->Dim(boxes, 1), 4, &unused));
-
- c->set_output(0, c->Vector(c->UnknownDim()));
- return Status::OK();
- });
+ .SetShapeFn(NMSShapeFn);
REGISTER_OP("NonMaxSuppressionV4")
.Input("boxes: float")
@@ -728,26 +729,16 @@ REGISTER_OP("NonMaxSuppressionV4")
.Output("valid_outputs: int32")
.Attr("pad_to_max_output_size: bool = false")
.SetShapeFn([](InferenceContext* c) {
- // Get inputs and validate ranks.
- ShapeHandle boxes;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &boxes));
- ShapeHandle scores;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &scores));
- ShapeHandle max_output_size;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &max_output_size));
- ShapeHandle iou_threshold;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &iou_threshold));
- ShapeHandle score_threshold;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &score_threshold));
- // The boxes is a 2-D float Tensor of shape [num_boxes, 4].
- DimensionHandle unused;
- // The boxes[0] and scores[0] are both num_boxes.
- TF_RETURN_IF_ERROR(
- c->Merge(c->Dim(boxes, 0), c->Dim(scores, 0), &unused));
- // The boxes[1] is 4.
- TF_RETURN_IF_ERROR(c->WithValue(c->Dim(boxes, 1), 4, &unused));
-
- c->set_output(0, c->Vector(c->UnknownDim()));
+ TF_RETURN_IF_ERROR(NMSShapeFn(c));
+
+ bool pad_to_max;
+ TF_RETURN_IF_ERROR(c->GetAttr("pad_to_max_output_size", &pad_to_max));
+ if (pad_to_max) {
+ // If padded, overwrite the shape of the output to be static.
+ DimensionHandle output_dim;
+ TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(2, &output_dim));
+ c->set_output(0, c->MakeShape({output_dim}));
+ }
c->set_output(1, c->MakeShape({}));
return Status::OK();
});