aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Mingxing Tan <tanmingxing@google.com>2017-08-29 15:08:38 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-08-29 15:15:14 -0700
commit0599901f020dd2e74b93bc4e12364027fbde1f1c (patch)
tree05533c0389a681349c2162460432319732f621c2
parentf282bb142f8d170549f6708f334bed7f40dd029e (diff)
Add extract_jpeg_shape such that we can get the image shape without actually
decoding the image. PiperOrigin-RevId: 166910421
-rw-r--r--tensorflow/core/kernels/BUILD8
-rw-r--r--tensorflow/core/kernels/extract_jpeg_shape_op.cc77
-rw-r--r--tensorflow/core/ops/image_ops.cc22
-rw-r--r--tensorflow/core/ops/image_ops_test.cc10
-rw-r--r--tensorflow/python/ops/image_ops.py1
-rw-r--r--tensorflow/python/ops/image_ops_test.py20
-rw-r--r--tensorflow/tools/api/golden/tensorflow.image.pbtxt4
7 files changed, 142 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index cf089a40ba..79a7b8a8b9 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -1807,6 +1807,7 @@ cc_library(
":draw_bounding_box_op",
":encode_jpeg_op",
":encode_png_op",
+ ":extract_jpeg_shape_op",
":non_max_suppression_op",
":random_crop_op",
":resize_area_op",
@@ -1904,6 +1905,12 @@ tf_kernel_library(
)
tf_kernel_library(
+ name = "extract_jpeg_shape_op",
+ prefix = "extract_jpeg_shape_op",
+ deps = IMAGE_DEPS,
+)
+
+tf_kernel_library(
name = "non_max_suppression_op",
prefix = "non_max_suppression_op",
deps = IMAGE_DEPS,
@@ -4651,6 +4658,7 @@ filegroup(
"decode_image_op.*",
"encode_png_op.*",
"encode_jpeg_op.*",
+ "extract_jpeg_shape_op.*",
"decode_jpeg_op.*",
"decode_gif_op.*",
"identity_reader_op.*",
diff --git a/tensorflow/core/kernels/extract_jpeg_shape_op.cc b/tensorflow/core/kernels/extract_jpeg_shape_op.cc
new file mode 100644
index 0000000000..60d798af56
--- /dev/null
+++ b/tensorflow/core/kernels/extract_jpeg_shape_op.cc
@@ -0,0 +1,77 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// See docs in ../ops/image_ops.cc
+
+#include <memory>
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/bounds_check.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/jpeg/jpeg_mem.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+
+// Extract the shape of a JPEG image.
+template <typename T>
+class ExtractJpegShapeOp : public OpKernel {
+ public:
+ explicit ExtractJpegShapeOp(OpKernelConstruction* context)
+ : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ // Get input content.
+ const Tensor& contents = context->input(0);
+ OP_REQUIRES(context, TensorShapeUtils::IsScalar(contents.shape()),
+ errors::InvalidArgument("contents must be scalar, got shape ",
+ contents.shape().DebugString()));
+ const StringPiece input = contents.scalar<string>()();
+ OP_REQUIRES(context, input.size() <= std::numeric_limits<int>::max(),
+ errors::InvalidArgument("JPEG contents are too large for int: ",
+ input.size()));
+
+ // Call GetImageInfo to get image shape.
+ int width, height, components;
+ OP_REQUIRES(
+ context,
+ jpeg::GetImageInfo(input.data(), input.size(), &width, &height,
+ &components),
+ errors::InvalidArgument("Invalid JPEG data, size ", input.size()));
+ // Allocate tensor and set shape size.
+ Tensor* image_shape = nullptr;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(0, TensorShape({3}), &image_shape));
+ auto image_shape_data = image_shape->tensor<T, 1>();
+ image_shape_data(0) = height;
+ image_shape_data(1) = width;
+ image_shape_data(2) = components;
+ }
+};
+
+#define REGISTER_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER(Name("ExtractJpegShape") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("output_type"), \
+ ExtractJpegShapeOp<type>)
+
+TF_CALL_int32(REGISTER_KERNELS);
+TF_CALL_int64(REGISTER_KERNELS);
+#undef REGISTER_KERNELS
+
+} // namespace tensorflow
diff --git a/tensorflow/core/ops/image_ops.cc b/tensorflow/core/ops/image_ops.cc
index 1bfa37f5a7..8ddf3561ce 100644
--- a/tensorflow/core/ops/image_ops.cc
+++ b/tensorflow/core/ops/image_ops.cc
@@ -458,6 +458,28 @@ contents: 0-D. JPEG-encoded image.
)doc");
// --------------------------------------------------------------------------
+REGISTER_OP("ExtractJpegShape")
+ .Input("contents: string")
+ .Output("image_shape: output_type")
+ .Attr("output_type: {int32, int64} = DT_INT32")
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle unused;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
+ c->set_output(0, c->Vector(3));
+ return Status::OK();
+ })
+ .Doc(R"doc(
+Extract the shape information of a JPEG-encoded image.
+
+This op only parses the image header, so it is much faster than DecodeJpeg.
+
+contents: 0-D. The JPEG-encoded image.
+image_shape: 1-D. The image shape with format [height, width, channels].
+output_type: (Optional) The output type of the operation (int32 or int64).
+ Defaults to int32.
+)doc");
+
+// --------------------------------------------------------------------------
REGISTER_OP("AdjustContrast")
.Input("images: T")
.Input("contrast_factor: float")
diff --git a/tensorflow/core/ops/image_ops_test.cc b/tensorflow/core/ops/image_ops_test.cc
index ea202edfb3..c757cefdda 100644
--- a/tensorflow/core/ops/image_ops_test.cc
+++ b/tensorflow/core/ops/image_ops_test.cc
@@ -101,6 +101,16 @@ TEST(ImageOpsTest, EncodeImage_ShapeFn) {
}
}
+TEST(ImageOpsTest, ExtractJpegShape_ShapeFn) {
+ ShapeInferenceTestOp op("ExtractJpegShape");
+
+ // Rank check.
+ INFER_ERROR("Shape must be rank 0 but is rank 1", op, "[1]");
+
+ // Only specify input data. Output must be a 1-D tensor with 3 elements.
+ INFER_OK(op, "?", "[3]");
+}
+
TEST(ImageOpsTest, Colorspace_ShapeFn) {
for (const char* op_name : {"HSVToRGB", "RGBToHSV"}) {
ShapeInferenceTestOp op(op_name);
diff --git a/tensorflow/python/ops/image_ops.py b/tensorflow/python/ops/image_ops.py
index 75c67dcb3c..31485ae9d4 100644
--- a/tensorflow/python/ops/image_ops.py
+++ b/tensorflow/python/ops/image_ops.py
@@ -22,6 +22,7 @@ See the @{$python/image} guide.
@@decode_gif
@@decode_jpeg
@@encode_jpeg
+@@extract_jpeg_shape
@@decode_png
@@encode_png
@@decode_image
diff --git a/tensorflow/python/ops/image_ops_test.py b/tensorflow/python/ops/image_ops_test.py
index 6356d88403..9e656f0e08 100644
--- a/tensorflow/python/ops/image_ops_test.py
+++ b/tensorflow/python/ops/image_ops_test.py
@@ -2455,6 +2455,26 @@ class JpegTest(test_util.TensorFlowTestCase):
self.assertEqual(image.get_shape().as_list(),
[None, None, channels or None])
+ def testExtractJpegShape(self):
+ # Read a real jpeg and verify shape.
+ path = ("tensorflow/core/lib/jpeg/testdata/"
+ "jpeg_merge_test1.jpg")
+ with self.test_session(use_gpu=True) as sess:
+ jpeg = io_ops.read_file(path)
+ # Extract shape without decoding.
+ [image_shape] = sess.run([image_ops.extract_jpeg_shape(jpeg)])
+ self.assertEqual(image_shape.tolist(), [256, 128, 3])
+
+ def testExtractJpegShapeforCmyk(self):
+ # Read a cmyk jpeg image, and verify its shape.
+ path = ("tensorflow/core/lib/jpeg/testdata/"
+ "jpeg_merge_test1_cmyk.jpg")
+ with self.test_session(use_gpu=True) as sess:
+ jpeg = io_ops.read_file(path)
+ [image_shape] = sess.run([image_ops.extract_jpeg_shape(jpeg)])
+ # Cmyk jpeg image has 4 channels.
+ self.assertEqual(image_shape.tolist(), [256, 128, 4])
+
class PngTest(test_util.TensorFlowTestCase):
diff --git a/tensorflow/tools/api/golden/tensorflow.image.pbtxt b/tensorflow/tools/api/golden/tensorflow.image.pbtxt
index 661d6fc586..764ffbb4b7 100644
--- a/tensorflow/tools/api/golden/tensorflow.image.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.image.pbtxt
@@ -77,6 +77,10 @@ tf_module {
argspec: "args=[\'input\', \'size\', \'offsets\', \'centered\', \'normalized\', \'uniform_noise\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'True\', \'None\'], "
}
member_method {
+ name: "extract_jpeg_shape"
+ argspec: "args=[\'contents\', \'output_type\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'int32\'>\", \'None\'], "
+ }
+ member_method {
name: "flip_left_right"
argspec: "args=[\'image\'], varargs=None, keywords=None, defaults=None"
}