diff options
author | 2018-05-14 17:57:23 -0700 | |
---|---|---|
committer | 2018-05-14 17:57:23 -0700 | |
commit | b84bff476514c1a2ee80d9f1bc31a9cb5dcc2ee5 (patch) | |
tree | 4399b9cc9332ba07fb9342f0b4ef5933f29ef187 /tensorflow | |
parent | 69c74f1e74eb5da964638533d594475ee9e54a66 (diff) |
Improve shape function of `tf.image.draw_bounding_boxes` (#19237)
* Improve shape function of `tf.image.draw_bounding_boxes`
The `tf.image.draw_bounding_boxes` requires `boxes` to be
3-D shape though there was no check on shape function.
This fix improves the shape function by restricting the
boxes to 3-D.
Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
* Add additional check to make sure boxes shape
ends with 4 ([batch, num_bounding_boxes, 4])
Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
* Address review feedback with addtional shape checks.
Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
* Add unit tests
Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
Diffstat (limited to 'tensorflow')
-rw-r--r-- | tensorflow/core/ops/image_ops.cc | 19 | ||||
-rw-r--r-- | tensorflow/core/ops/image_ops_test.cc | 19 |
2 files changed, 38 insertions, 0 deletions
diff --git a/tensorflow/core/ops/image_ops.cc b/tensorflow/core/ops/image_ops.cc index c3b08e067a..cccfc4736e 100644 --- a/tensorflow/core/ops/image_ops.cc +++ b/tensorflow/core/ops/image_ops.cc @@ -435,6 +435,25 @@ REGISTER_OP("DrawBoundingBoxes") .Output("output: T") .Attr("T: {float, half} = DT_FLOAT") .SetShapeFn([](InferenceContext* c) { + // The rank of images should be 4. + ShapeHandle images; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &images)); + // Channel depth should be either 1 (GRY), 3 (RGB), or 4 (RGBA). + if (c->ValueKnown(c->Dim(images, 3))) { + int64 depth = c->Value(c->Dim(images, 3)); + if (!(depth == 1 || depth == 3 || depth == 4)) { + return errors::InvalidArgument("Channel depth should be either 1 (GRY), " + "3 (RGB), or 4 (RGBA)"); + } + } + + // The rank of boxes is 3: [batch, num_bounding_boxes, 4]. + ShapeHandle boxes; + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &boxes)); + // The last value of boxes shape is 4. + DimensionHandle unused; + TF_RETURN_IF_ERROR(c->WithValue(c->Dim(boxes, 2), 4, &unused)); + return shape_inference::UnchangedShapeWithRankAtLeast(c, 3); }); diff --git a/tensorflow/core/ops/image_ops_test.cc b/tensorflow/core/ops/image_ops_test.cc index 5f0b391b0d..517af26b44 100644 --- a/tensorflow/core/ops/image_ops_test.cc +++ b/tensorflow/core/ops/image_ops_test.cc @@ -312,4 +312,23 @@ TEST(ImageOpsTest, QuantizedResizeBilinear_ShapeFn) { INFER_OK(op, "[1,?,3,?];[2];[];[]", "[d0_0,20,30,d0_3];[];[]"); } +TEST(ImageOpsTest, DrawBoundingBoxes_ShapeFn) { + ShapeInferenceTestOp op("DrawBoundingBoxes"); + op.input_tensors.resize(2); + + // Check images. + INFER_ERROR("must be rank 4", op, "[1,?,3];?"); + INFER_ERROR("should be either 1 (GRY), 3 (RGB), or 4 (RGBA)", + op, "[1,?,?,5];?"); + + // Check boxes. + INFER_ERROR("must be rank 3", op, "[1,?,?,4];[1,4]"); + INFER_ERROR("Dimension must be 4", op, "[1,?,?,4];[1,2,2]"); + + // OK shapes. + INFER_OK(op, "[4,?,?,4];?", "in0"); + INFER_OK(op, "[?,?,?,?];[?,?,?]", "in0"); + INFER_OK(op, "[4,?,?,4];[?,?,?]", "in0"); + INFER_OK(op, "[4,?,?,4];[?,?,4]", "in0"); +} } // end namespace tensorflow |