diff options
author | Mingxing Tan <tanmingxing@google.com> | 2017-08-29 15:08:38 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-08-29 15:15:14 -0700 |
commit | 0599901f020dd2e74b93bc4e12364027fbde1f1c (patch) | |
tree | 05533c0389a681349c2162460432319732f621c2 | |
parent | f282bb142f8d170549f6708f334bed7f40dd029e (diff) |
Add extract_jpeg_shape so that we can get the image shape without actually
decoding the image.
PiperOrigin-RevId: 166910421
-rw-r--r-- | tensorflow/core/kernels/BUILD | 8 | ||||
-rw-r--r-- | tensorflow/core/kernels/extract_jpeg_shape_op.cc | 77 | ||||
-rw-r--r-- | tensorflow/core/ops/image_ops.cc | 22 | ||||
-rw-r--r-- | tensorflow/core/ops/image_ops_test.cc | 10 | ||||
-rw-r--r-- | tensorflow/python/ops/image_ops.py | 1 | ||||
-rw-r--r-- | tensorflow/python/ops/image_ops_test.py | 20 | ||||
-rw-r--r-- | tensorflow/tools/api/golden/tensorflow.image.pbtxt | 4 |
7 files changed, 142 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index cf089a40ba..79a7b8a8b9 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -1807,6 +1807,7 @@ cc_library( ":draw_bounding_box_op", ":encode_jpeg_op", ":encode_png_op", + ":extract_jpeg_shape_op", ":non_max_suppression_op", ":random_crop_op", ":resize_area_op", @@ -1904,6 +1905,12 @@ tf_kernel_library( ) tf_kernel_library( + name = "extract_jpeg_shape_op", + prefix = "extract_jpeg_shape_op", + deps = IMAGE_DEPS, +) + +tf_kernel_library( name = "non_max_suppression_op", prefix = "non_max_suppression_op", deps = IMAGE_DEPS, @@ -4651,6 +4658,7 @@ filegroup( "decode_image_op.*", "encode_png_op.*", "encode_jpeg_op.*", + "extract_jpeg_shape_op.*", "decode_jpeg_op.*", "decode_gif_op.*", "identity_reader_op.*", diff --git a/tensorflow/core/kernels/extract_jpeg_shape_op.cc b/tensorflow/core/kernels/extract_jpeg_shape_op.cc new file mode 100644 index 0000000000..60d798af56 --- /dev/null +++ b/tensorflow/core/kernels/extract_jpeg_shape_op.cc @@ -0,0 +1,77 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +// See docs in ../ops/image_ops.cc + +#include <memory> +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/bounds_check.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/jpeg/jpeg_mem.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { + +// Extract the shape of a JPEG image. +template <typename T> +class ExtractJpegShapeOp : public OpKernel { + public: + explicit ExtractJpegShapeOp(OpKernelConstruction* context) + : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + // Get input content. + const Tensor& contents = context->input(0); + OP_REQUIRES(context, TensorShapeUtils::IsScalar(contents.shape()), + errors::InvalidArgument("contents must be scalar, got shape ", + contents.shape().DebugString())); + const StringPiece input = contents.scalar<string>()(); + OP_REQUIRES(context, input.size() <= std::numeric_limits<int>::max(), + errors::InvalidArgument("JPEG contents are too large for int: ", + input.size())); + + // Call GetImageInfo to get image shape. + int width, height, components; + OP_REQUIRES( + context, + jpeg::GetImageInfo(input.data(), input.size(), &width, &height, + &components), + errors::InvalidArgument("Invalid JPEG data, size ", input.size())); + // Allocate tensor and set shape size. 
+ Tensor* image_shape = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output(0, TensorShape({3}), &image_shape)); + auto image_shape_data = image_shape->tensor<T, 1>(); + image_shape_data(0) = height; + image_shape_data(1) = width; + image_shape_data(2) = components; + } +}; + +#define REGISTER_KERNELS(type) \ + REGISTER_KERNEL_BUILDER(Name("ExtractJpegShape") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<type>("output_type"), \ + ExtractJpegShapeOp<type>) + +TF_CALL_int32(REGISTER_KERNELS); +TF_CALL_int64(REGISTER_KERNELS); +#undef REGISTER_KERNELS + +} // namespace tensorflow diff --git a/tensorflow/core/ops/image_ops.cc b/tensorflow/core/ops/image_ops.cc index 1bfa37f5a7..8ddf3561ce 100644 --- a/tensorflow/core/ops/image_ops.cc +++ b/tensorflow/core/ops/image_ops.cc @@ -458,6 +458,28 @@ contents: 0-D. JPEG-encoded image. )doc"); // -------------------------------------------------------------------------- +REGISTER_OP("ExtractJpegShape") + .Input("contents: string") + .Output("image_shape: output_type") + .Attr("output_type: {int32, int64} = DT_INT32") + .SetShapeFn([](InferenceContext* c) { + ShapeHandle unused; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused)); + c->set_output(0, c->Vector(3)); + return Status::OK(); + }) + .Doc(R"doc( +Extract the shape information of a JPEG-encoded image. + +This op only parses the image header, so it is much faster than DecodeJpeg. + +contents: 0-D. The JPEG-encoded image. +image_shape: 1-D. The image shape with format [height, width, channels]. +output_type: (Optional) The output type of the operation (int32 or int64). + Defaults to int32. 
+)doc"); + +// -------------------------------------------------------------------------- REGISTER_OP("AdjustContrast") .Input("images: T") .Input("contrast_factor: float") diff --git a/tensorflow/core/ops/image_ops_test.cc b/tensorflow/core/ops/image_ops_test.cc index ea202edfb3..c757cefdda 100644 --- a/tensorflow/core/ops/image_ops_test.cc +++ b/tensorflow/core/ops/image_ops_test.cc @@ -101,6 +101,16 @@ TEST(ImageOpsTest, EncodeImage_ShapeFn) { } } +TEST(ImageOpsTest, ExtractJpegShape_ShapeFn) { + ShapeInferenceTestOp op("ExtractJpegShape"); + + // Rank check. + INFER_ERROR("Shape must be rank 0 but is rank 1", op, "[1]"); + + // Only specify input data. Output must be a 1-D tensor with 3 elements. + INFER_OK(op, "?", "[3]"); +} + TEST(ImageOpsTest, Colorspace_ShapeFn) { for (const char* op_name : {"HSVToRGB", "RGBToHSV"}) { ShapeInferenceTestOp op(op_name); diff --git a/tensorflow/python/ops/image_ops.py b/tensorflow/python/ops/image_ops.py index 75c67dcb3c..31485ae9d4 100644 --- a/tensorflow/python/ops/image_ops.py +++ b/tensorflow/python/ops/image_ops.py @@ -22,6 +22,7 @@ See the @{$python/image} guide. @@decode_gif @@decode_jpeg @@encode_jpeg +@@extract_jpeg_shape @@decode_png @@encode_png @@decode_image diff --git a/tensorflow/python/ops/image_ops_test.py b/tensorflow/python/ops/image_ops_test.py index 6356d88403..9e656f0e08 100644 --- a/tensorflow/python/ops/image_ops_test.py +++ b/tensorflow/python/ops/image_ops_test.py @@ -2455,6 +2455,26 @@ class JpegTest(test_util.TensorFlowTestCase): self.assertEqual(image.get_shape().as_list(), [None, None, channels or None]) + def testExtractJpegShape(self): + # Read a real jpeg and verify shape. + path = ("tensorflow/core/lib/jpeg/testdata/" + "jpeg_merge_test1.jpg") + with self.test_session(use_gpu=True) as sess: + jpeg = io_ops.read_file(path) + # Extract shape without decoding. 
+ [image_shape] = sess.run([image_ops.extract_jpeg_shape(jpeg)]) + self.assertEqual(image_shape.tolist(), [256, 128, 3]) + + def testExtractJpegShapeforCmyk(self): + # Read a cmyk jpeg image, and verify its shape. + path = ("tensorflow/core/lib/jpeg/testdata/" + "jpeg_merge_test1_cmyk.jpg") + with self.test_session(use_gpu=True) as sess: + jpeg = io_ops.read_file(path) + [image_shape] = sess.run([image_ops.extract_jpeg_shape(jpeg)]) + # Cmyk jpeg image has 4 channels. + self.assertEqual(image_shape.tolist(), [256, 128, 4]) + class PngTest(test_util.TensorFlowTestCase): diff --git a/tensorflow/tools/api/golden/tensorflow.image.pbtxt b/tensorflow/tools/api/golden/tensorflow.image.pbtxt index 661d6fc586..764ffbb4b7 100644 --- a/tensorflow/tools/api/golden/tensorflow.image.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.image.pbtxt @@ -77,6 +77,10 @@ tf_module { argspec: "args=[\'input\', \'size\', \'offsets\', \'centered\', \'normalized\', \'uniform_noise\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'True\', \'None\'], " } member_method { + name: "extract_jpeg_shape" + argspec: "args=[\'contents\', \'output_type\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'int32\'>\", \'None\'], " + } + member_method { name: "flip_left_right" argspec: "args=[\'image\'], varargs=None, keywords=None, defaults=None" } |