aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar Yong Tang <yong.tang.github@outlook.com>2018-05-14 17:57:23 -0700
committerGravatar Rasmus Munk Larsen <rmlarsen@google.com>2018-05-14 17:57:23 -0700
commitb84bff476514c1a2ee80d9f1bc31a9cb5dcc2ee5 (patch)
tree4399b9cc9332ba07fb9342f0b4ef5933f29ef187 /tensorflow
parent69c74f1e74eb5da964638533d594475ee9e54a66 (diff)
Improve shape function of `tf.image.draw_bounding_boxes` (#19237)
* Improve shape function of `tf.image.draw_bounding_boxes` The `tf.image.draw_bounding_boxes` requires `boxes` to be 3-D shape though there was no check on shape function. This fix improves the shape function by restricting the boxes to 3-D. Signed-off-by: Yong Tang <yong.tang.github@outlook.com> * Add additional check to make sure boxes shape ends with 4 ([batch, num_bounding_boxes, 4]) Signed-off-by: Yong Tang <yong.tang.github@outlook.com> * Address review feedback with addtional shape checks. Signed-off-by: Yong Tang <yong.tang.github@outlook.com> * Add unit tests Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/core/ops/image_ops.cc19
-rw-r--r--tensorflow/core/ops/image_ops_test.cc19
2 files changed, 38 insertions, 0 deletions
diff --git a/tensorflow/core/ops/image_ops.cc b/tensorflow/core/ops/image_ops.cc
index c3b08e067a..cccfc4736e 100644
--- a/tensorflow/core/ops/image_ops.cc
+++ b/tensorflow/core/ops/image_ops.cc
@@ -435,6 +435,25 @@ REGISTER_OP("DrawBoundingBoxes")
.Output("output: T")
.Attr("T: {float, half} = DT_FLOAT")
.SetShapeFn([](InferenceContext* c) {
+ // The rank of images should be 4.
+ ShapeHandle images;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &images));
+ // Channel depth should be either 1 (GRY), 3 (RGB), or 4 (RGBA).
+ if (c->ValueKnown(c->Dim(images, 3))) {
+ int64 depth = c->Value(c->Dim(images, 3));
+ if (!(depth == 1 || depth == 3 || depth == 4)) {
+ return errors::InvalidArgument("Channel depth should be either 1 (GRY), "
+ "3 (RGB), or 4 (RGBA)");
+ }
+ }
+
+ // The rank of boxes is 3: [batch, num_bounding_boxes, 4].
+ ShapeHandle boxes;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &boxes));
+ // The last value of boxes shape is 4.
+ DimensionHandle unused;
+ TF_RETURN_IF_ERROR(c->WithValue(c->Dim(boxes, 2), 4, &unused));
+
return shape_inference::UnchangedShapeWithRankAtLeast(c, 3);
});
diff --git a/tensorflow/core/ops/image_ops_test.cc b/tensorflow/core/ops/image_ops_test.cc
index 5f0b391b0d..517af26b44 100644
--- a/tensorflow/core/ops/image_ops_test.cc
+++ b/tensorflow/core/ops/image_ops_test.cc
@@ -312,4 +312,23 @@ TEST(ImageOpsTest, QuantizedResizeBilinear_ShapeFn) {
INFER_OK(op, "[1,?,3,?];[2];[];[]", "[d0_0,20,30,d0_3];[];[]");
}
+TEST(ImageOpsTest, DrawBoundingBoxes_ShapeFn) {
+ ShapeInferenceTestOp op("DrawBoundingBoxes");
+ op.input_tensors.resize(2);
+
+ // Check images.
+ INFER_ERROR("must be rank 4", op, "[1,?,3];?");
+ INFER_ERROR("should be either 1 (GRY), 3 (RGB), or 4 (RGBA)",
+ op, "[1,?,?,5];?");
+
+ // Check boxes.
+ INFER_ERROR("must be rank 3", op, "[1,?,?,4];[1,4]");
+ INFER_ERROR("Dimension must be 4", op, "[1,?,?,4];[1,2,2]");
+
+ // OK shapes.
+ INFER_OK(op, "[4,?,?,4];?", "in0");
+ INFER_OK(op, "[?,?,?,?];[?,?,?]", "in0");
+ INFER_OK(op, "[4,?,?,4];[?,?,?]", "in0");
+ INFER_OK(op, "[4,?,?,4];[?,?,4]", "in0");
+}
} // end namespace tensorflow