diff options
Diffstat (limited to 'tensorflow/python/ops/image_ops_impl.py')
-rw-r--r-- | tensorflow/python/ops/image_ops_impl.py | 56 |
1 files changed, 56 insertions, 0 deletions
diff --git a/tensorflow/python/ops/image_ops_impl.py b/tensorflow/python/ops/image_ops_impl.py index 639df5d845..15694d4b3f 100644 --- a/tensorflow/python/ops/image_ops_impl.py +++ b/tensorflow/python/ops/image_ops_impl.py @@ -32,6 +32,7 @@ from tensorflow.python.ops import clip_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gen_image_ops from tensorflow.python.ops import gen_nn_ops +from tensorflow.python.ops import string_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import variables @@ -1204,3 +1205,58 @@ def adjust_saturation(image, saturation_factor, name=None): rgb_altered = gen_image_ops.hsv_to_rgb(hsv_altered) return convert_image_dtype(rgb_altered, orig_dtype) + + +def decode_image(contents, channels=None, name=None): + """Convenience function for `decode_gif`, `decode_jpeg`, and `decode_png`. + Detects whether an image is a GIF, JPEG, or PNG, and performs the appropriate + operation to convert the input bytes `string` into a `Tensor` of type `uint8`. + + Note: `decode_gif` returns a 4-D array `[num_frames, height, width, 3]`, as + opposed to `decode_jpeg` and `decode_png`, which return 3-D arrays + `[height, width, num_channels]`. Make sure to take this into account when + constructing your graph if you are intermixing GIF files with JPEG and/or PNG + files. + + Args: + contents: 0-D `string`. The encoded image bytes. + channels: An optional `int`. Defaults to `0`. Number of color channels for + the decoded image. + name: A name for the operation (optional) + + Returns: + `Tensor` with type `uint8` with shape `[height, width, num_channels]` for + JPEG and PNG images and shape `[num_frames, height, width, 3]` for GIF + images. + """ + with ops.name_scope(name, 'decode_image') as scope: + if channels not in (None, 0, 1, 3): + raise ValueError('channels must be in (None, 0, 1, 3)') + substr = string_ops.substr(contents, 0, 4) + + def _gif(): + # Create assert op to check that bytes are GIF decodable + is_gif = math_ops.equal(substr, b'\x47\x49\x46\x38', name='is_gif') + decode_msg = 'Unable to decode bytes as JPEG, PNG, or GIF' + assert_decode = control_flow_ops.Assert(is_gif, [decode_msg]) + # Create assert to make sure that channels is not set to 1 + # Already checked above that channels is in (None, 0, 1, 3) + gif_channels = 0 if channels is None else channels + good_channels = math_ops.not_equal(gif_channels, 1, name='check_channels') + channels_msg = 'Channels must be in (None, 0, 3) when decoding GIF images' + assert_channels = control_flow_ops.Assert(good_channels, [channels_msg]) + with ops.control_dependencies([assert_decode, assert_channels]): + return gen_image_ops.decode_gif(contents) + + def _png(): + return gen_image_ops.decode_png(contents, channels) + + def check_png(): + is_png = math_ops.equal(substr, b'\211PNG', name='is_png') + return control_flow_ops.cond(is_png, _png, _gif, name='cond_png') + + def _jpeg(): + return gen_image_ops.decode_jpeg(contents, channels) + + is_jpeg = math_ops.equal(substr, b'\xff\xd8\xff\xe0', name='is_jpeg') + return control_flow_ops.cond(is_jpeg, _jpeg, check_png, name='cond_jpeg') |