From 568127ac3b8e501bb230ee287ec9a46129fad349 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Tue, 17 Oct 2017 12:23:29 -0700 Subject: Improve shape inference with `DecodeAndCropJpeg` (#13750) * Improve shape inference with `DecodeAndCropJpeg` While working on improving shape inference for several other ops in 13561 and 13193, I noticed that `DecodeAndCropJpeg` does not inference shape even though crop size might have already be provided. In that case the shape will be `[h, w, channel]` and `h`, `w` is part of the `crop_window`. This fix updates the shape function in `DecodeAndCropJpeg` for improving shape inference. Signed-off-by: Yong Tang * Add test cases to cover shape inference for `DecodeAndCropJpeg` Signed-off-by: Yong Tang * Address failed unit tests Signed-off-by: Yong Tang --- tensorflow/core/ops/image_ops.cc | 31 ++++++++++++++++++++++++++++++- tensorflow/core/ops/image_ops_test.cc | 6 +++--- tensorflow/python/ops/image_ops_test.py | 6 +++++- 3 files changed, 38 insertions(+), 5 deletions(-) diff --git a/tensorflow/core/ops/image_ops.cc b/tensorflow/core/ops/image_ops.cc index 66765a3333..89c9da81c5 100644 --- a/tensorflow/core/ops/image_ops.cc +++ b/tensorflow/core/ops/image_ops.cc @@ -453,7 +453,36 @@ REGISTER_OP("DecodeAndCropJpeg") .Attr("acceptable_fraction: float = 1.0") .Attr("dct_method: string = ''") .Output("image: uint8") - .SetShapeFn(DecodeImageShapeFn) + .SetShapeFn([](InferenceContext* c) { + ShapeHandle unused; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused)); + DimensionHandle channels_dim = c->UnknownDim(); + DimensionHandle h = c->UnknownDim(); + DimensionHandle w = c->UnknownDim(); + + int32 channels; + TF_RETURN_IF_ERROR(c->GetAttr("channels", &channels)); + if (channels != 0) { + if (channels < 0) { + return errors::InvalidArgument("channels must be non-negative, got ", + channels); + } + channels_dim = c->MakeDim(channels); + } + + DimensionHandle unused_dim; + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused)); + TF_RETURN_IF_ERROR(c->WithValue(c->Dim(unused, 0), 4, &unused_dim)); + + const Tensor* crop_window = c->input_tensor(1); + if (crop_window != nullptr) { + auto crop_window_vec = crop_window->vec(); + h = c->MakeDim(crop_window_vec(2)); + w = c->MakeDim(crop_window_vec(3)); + } + c->set_output(0, c->MakeShape({h, w, channels_dim})); + return Status::OK(); + }) .Doc(strings::StrCat(R"doc( Decode and Crop a JPEG-encoded image to a uint8 tensor. )doc", diff --git a/tensorflow/core/ops/image_ops_test.cc b/tensorflow/core/ops/image_ops_test.cc index c34b11a15e..5f0b391b0d 100644 --- a/tensorflow/core/ops/image_ops_test.cc +++ b/tensorflow/core/ops/image_ops_test.cc @@ -105,7 +105,7 @@ TEST(ImageOpsTest, DecodeAndCropJpeg_ShapeFn) { .Input({"img", 0, DT_STRING}) .Input({"crop_window", 1, DT_INT32}) .Finalize(&op.node_def)); - INFER_OK(op, "[];[]", "[?,?,?]"); + INFER_OK(op, "[];[?]", "[?,?,?]"); // Set the channel, so that part of output shape is known. TF_ASSERT_OK(NodeDefBuilder("test", op_name) @@ -113,7 +113,7 @@ TEST(ImageOpsTest, DecodeAndCropJpeg_ShapeFn) { .Input({"crop_window", 1, DT_INT32}) .Attr("channels", 4) .Finalize(&op.node_def)); - INFER_OK(op, "[];[]", "[?,?,4]"); + INFER_OK(op, "[];[?]", "[?,?,4]"); // Negative channel value is rejected. TF_ASSERT_OK(NodeDefBuilder("test", op_name) @@ -139,7 +139,7 @@ TEST(ImageOpsTest, DecodeAndCropJpeg_InvalidCropWindow) { .Input({"img", 0, DT_STRING}) .Input({"crop_window", 1, DT_INT32}) .Finalize(&op.node_def)); - INFER_OK(op, "[];[]", "[?,?,?]"); + INFER_OK(op, "[];[?]", "[?,?,?]"); } TEST(ImageOpsTest, EncodeImage_ShapeFn) { diff --git a/tensorflow/python/ops/image_ops_test.py b/tensorflow/python/ops/image_ops_test.py index 348c005ff3..b13b73edbb 100644 --- a/tensorflow/python/ops/image_ops_test.py +++ b/tensorflow/python/ops/image_ops_test.py @@ -2434,9 +2434,13 @@ class JpegTest(test_util.TensorFlowTestCase): y, x, h, w = crop_window image1_crop = image_ops.crop_to_bounding_box(image1, y, x, h, w) - # Combined crop+decode. + # Combined decode+crop. image2 = image_ops.decode_and_crop_jpeg(jpeg0, crop_window) + # Combined decode+crop should have the same shape inference + self.assertAllEqual(image1_crop.get_shape().as_list(), + image2.get_shape().as_list()) + # CropAndDecode should be equal to DecodeJpeg+Crop. image1_crop, image2 = sess.run([image1_crop, image2]) self.assertAllEqual(image1_crop, image2) -- cgit v1.2.3