aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/decode_png_op.cc
diff options
context:
space:
mode:
authorGravatar Geoffrey Irving <geoffreyi@google.com>2017-05-12 11:14:24 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-05-12 11:18:28 -0700
commit04891087aa887ac744c5ec2d991314b5b3c4f78e (patch)
tree90b8492e6475034ef39420c5a6c9baf25512c985 /tensorflow/core/kernels/decode_png_op.cc
parent365b5c1c322df7870688d9f08320d4012ef1b6bc (diff)
Make all the image decode ops handle all formats
Too many users try to decode pngs as jpegs. Now, if you pass a png to decode_jpeg, or a gif to decode_png, it silently does what the user was expecting. Unfortunately, tf.image.decode_image still exists as a separate thing in Python, since decode_gif returns 4-D shapes incompatible with the other ops. A future CL could clean that up, but it's hard to do in a backwards compatible way. As is, decode_png and decode_jpeg will bail if you try to decode an animated gif, and produce 3-D images for nonanimated gifs. Also fix some crash-on-error bugs and memory leak bugs in gif_io.cc. RELNOTES: Make decode_jpeg/decode_png/decode_gif handle all formats, since users frequently try to decode an image as the wrong type. Fixes #9786. PiperOrigin-RevId: 155888493
Diffstat (limited to 'tensorflow/core/kernels/decode_png_op.cc')
-rw-r--r--tensorflow/core/kernels/decode_png_op.cc118
1 files changed, 0 insertions, 118 deletions
diff --git a/tensorflow/core/kernels/decode_png_op.cc b/tensorflow/core/kernels/decode_png_op.cc
deleted file mode 100644
index 1906ae7746..0000000000
--- a/tensorflow/core/kernels/decode_png_op.cc
+++ /dev/null
@@ -1,118 +0,0 @@
-/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-// See docs in ../ops/image_ops.cc
-
-#include <memory>
-#include "tensorflow/core/framework/op_kernel.h"
-#include "tensorflow/core/framework/register_types.h"
-#include "tensorflow/core/framework/tensor.h"
-#include "tensorflow/core/framework/tensor_shape.h"
-#include "tensorflow/core/framework/types.h"
-#include "tensorflow/core/framework/types.pb.h"
-#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/png/png_io.h"
-#include "tensorflow/core/platform/logging.h"
-
-namespace tensorflow {
-
-// Decode the contents of a PNG file
-class DecodePngOp : public OpKernel {
- public:
- explicit DecodePngOp(OpKernelConstruction* context) : OpKernel(context) {
- OP_REQUIRES_OK(context, context->GetAttr("channels", &channels_));
- OP_REQUIRES(context, channels_ == 0 || channels_ == 1 || channels_ == 3 ||
- channels_ == 4,
- errors::InvalidArgument("channels must be 0, 1, 3, or 4, got ",
- channels_));
-
- DataType dt;
- OP_REQUIRES_OK(context, context->GetAttr("dtype", &dt));
- OP_REQUIRES(
- context, dt == DataType::DT_UINT8 || dt == DataType::DT_UINT16,
- errors::InvalidArgument("Type must be UINT8 or UINT16, got ", dt));
- if (dt == DataType::DT_UINT8) {
- desired_channel_bits_ = 8;
- } else {
- desired_channel_bits_ = 16;
- }
- }
-
- void Compute(OpKernelContext* context) override {
- const Tensor& contents = context->input(0);
- OP_REQUIRES(context, TensorShapeUtils::IsScalar(contents.shape()),
- errors::InvalidArgument("contents must be scalar, got shape ",
- contents.shape().DebugString()));
-
- // Start decoding image to get shape details
- const StringPiece data = contents.scalar<string>()();
- png::DecodeContext decode;
- OP_REQUIRES(
- context,
- png::CommonInitDecode(data, channels_, desired_channel_bits_, &decode),
- errors::InvalidArgument("Invalid PNG header, data size ", data.size()));
-
- // Verify that width and height are not too large:
- // - verify width and height don't overflow int.
- // - width can later be multiplied by channels_ and sizeof(uint16), so
- // verify single dimension is not too large.
- // - verify when width and height are multiplied together, there are a few
- // bits to spare as well.
- const int width = static_cast<int>(decode.width);
- const int height = static_cast<int>(decode.height);
- const int64 total_size =
- static_cast<int64>(width) * static_cast<int64>(height);
- if (width != static_cast<int64>(decode.width) || width <= 0 ||
- width >= (1LL << 27) || height != static_cast<int64>(decode.height) ||
- height <= 0 || height >= (1LL << 27) || total_size >= (1LL << 29)) {
- png::CommonFreeDecode(&decode);
- OP_REQUIRES(context, false,
- errors::InvalidArgument("PNG size too large for int: ",
- decode.width, " by ", decode.height));
- }
-
- // Allocate tensor
- Tensor* output = nullptr;
- const auto status = context->allocate_output(
- 0, TensorShape({height, width, decode.channels}), &output);
- if (!status.ok()) png::CommonFreeDecode(&decode);
- OP_REQUIRES_OK(context, status);
-
- if (desired_channel_bits_ == 8) {
- // Finish decoding image
- OP_REQUIRES(
- context,
- png::CommonFinishDecode(
- reinterpret_cast<png_bytep>(output->flat<uint8>().data()),
- decode.channels * width * sizeof(uint8), &decode),
- errors::InvalidArgument("Invalid PNG data, size ", data.size()));
- } else {
- // Finish decoding image
- OP_REQUIRES(
- context,
- png::CommonFinishDecode(
- reinterpret_cast<png_bytep>(output->flat<uint16>().data()),
- decode.channels * width * sizeof(uint16), &decode),
- errors::InvalidArgument("Invalid PNG data, size ", data.size()));
- }
- }
-
- private:
- int channels_;
- int desired_channel_bits_;
-};
-REGISTER_KERNEL_BUILDER(Name("DecodePng").Device(DEVICE_CPU), DecodePngOp);
-
-} // namespace tensorflow