aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/ops/image_ops.cc
diff options
context:
space:
mode:
authorGravatar Michael Case <mikecase@google.com>2018-02-07 14:36:00 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-07 14:39:49 -0800
commitd90054e7c0f41f4bab81df0548577a73b939a87a (patch)
treea15aea686a9d3f305e316d2a6ada0859ad8170d1 /tensorflow/core/ops/image_ops.cc
parent8461760f9f6cde8ed97507484d2a879140141032 (diff)
Merge changes from github.
PiperOrigin-RevId: 184897758
Diffstat (limited to 'tensorflow/core/ops/image_ops.cc')
-rw-r--r--tensorflow/core/ops/image_ops.cc24
1 files changed, 24 insertions, 0 deletions
diff --git a/tensorflow/core/ops/image_ops.cc b/tensorflow/core/ops/image_ops.cc
index ef2ac267cc..a62e2d782b 100644
--- a/tensorflow/core/ops/image_ops.cc
+++ b/tensorflow/core/ops/image_ops.cc
@@ -586,6 +586,17 @@ REGISTER_OP("NonMaxSuppression")
.Output("selected_indices: int32")
.Attr("iou_threshold: float = 0.5")
.SetShapeFn([](InferenceContext* c) {
+ // Get inputs and validate ranks.
+ ShapeHandle boxes;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &boxes));
+ ShapeHandle scores;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &scores));
+ ShapeHandle max_output_size;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &max_output_size));
+ // The boxes is a 2-D float Tensor of shape [num_boxes, 4].
+ DimensionHandle unused;
+ TF_RETURN_IF_ERROR(c->WithValue(c->Dim(boxes, 1), 4, &unused));
+
c->set_output(0, c->Vector(c->UnknownDim()));
return Status::OK();
});
@@ -597,6 +608,19 @@ REGISTER_OP("NonMaxSuppressionV2")
.Input("iou_threshold: float")
.Output("selected_indices: int32")
.SetShapeFn([](InferenceContext* c) {
+ // Get inputs and validate ranks.
+ ShapeHandle boxes;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &boxes));
+ ShapeHandle scores;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &scores));
+ ShapeHandle max_output_size;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &max_output_size));
+ ShapeHandle iou_threshold;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &iou_threshold));
+ // The boxes is a 2-D float Tensor of shape [num_boxes, 4].
+ DimensionHandle unused;
+ TF_RETURN_IF_ERROR(c->WithValue(c->Dim(boxes, 1), 4, &unused));
+
c->set_output(0, c->Vector(c->UnknownDim()));
return Status::OK();
});