diff options
Diffstat (limited to 'tensorflow/python/ops/image_ops_impl.py')
-rw-r--r-- | tensorflow/python/ops/image_ops_impl.py | 46 |
1 files changed, 30 insertions, 16 deletions
diff --git a/tensorflow/python/ops/image_ops_impl.py b/tensorflow/python/ops/image_ops_impl.py index d860a3b618..b713c44717 100644 --- a/tensorflow/python/ops/image_ops_impl.py +++ b/tensorflow/python/ops/image_ops_impl.py @@ -148,6 +148,28 @@ def _Check3DImage(image, require_static=True): return [] +def _Assert3DImage(image): + """Assert that we are working with a properly shaped image. + + Performs the check statically if possible (i.e. if the shape + is statically known). Otherwise adds a control dependency + to an assert op that checks the dynamic shape. + + Args: + image: 3-D Tensor of shape [height, width, channels] + + Raises: + ValueError: if `image.shape` is not a 3-vector. + + Returns: + If the shape of `image` could be verified statically, `image` is + returned unchanged, otherwise there will be a control dependency + added that asserts the correct dynamic shape. + """ + return control_flow_ops.with_dependencies( + _Check3DImage(image, require_static=False), image) + + def _CheckAtLeast3DImage(image, require_static=True): """Assert that we are working with properly shaped image. @@ -223,8 +245,7 @@ def random_flip_up_down(image, seed=None): """ with ops.name_scope(None, 'random_flip_up_down', [image]) as scope: image = ops.convert_to_tensor(image, name='image') - image = control_flow_ops.with_dependencies( - _Check3DImage(image, require_static=False), image) + image = _Assert3DImage(image) uniform_random = random_ops.random_uniform([], 0, 1.0, seed=seed) mirror_cond = math_ops.less(uniform_random, .5) result = control_flow_ops.cond(mirror_cond, @@ -255,8 +276,7 @@ def random_flip_left_right(image, seed=None): """ with ops.name_scope(None, 'random_flip_left_right', [image]) as scope: image = ops.convert_to_tensor(image, name='image') - image = control_flow_ops.with_dependencies( - _Check3DImage(image, require_static=False), image) + image = _Assert3DImage(image) uniform_random = random_ops.random_uniform([], 0, 1.0, seed=seed) mirror_cond = math_ops.less(uniform_random, .5) result = control_flow_ops.cond(mirror_cond, @@ -286,8 +306,7 @@ def flip_left_right(image): """ with ops.name_scope(None, 'flip_left_right', [image]) as scope: image = ops.convert_to_tensor(image, name='image') - image = control_flow_ops.with_dependencies( - _Check3DImage(image, require_static=False), image) + image = _Assert3DImage(image) return fix_image_flip_shape(image, array_ops.reverse(image, [1], name=scope)) @@ -312,8 +331,7 @@ def flip_up_down(image): """ with ops.name_scope(None, 'flip_up_down', [image]) as scope: image = ops.convert_to_tensor(image, name='image') - image = control_flow_ops.with_dependencies( - _Check3DImage(image, require_static=False), image) + image = _Assert3DImage(image) return fix_image_flip_shape(image, array_ops.reverse(image, [0], name=scope)) @@ -332,8 +350,7 @@ def rot90(image, k=1, name=None): """ with ops.name_scope(name, 'rot90', [image, k]) as scope: image = ops.convert_to_tensor(image, name='image') - image = control_flow_ops.with_dependencies( - _Check3DImage(image, require_static=False), image) + image = _Assert3DImage(image) k = ops.convert_to_tensor(k, dtype=dtypes.int32, name='k') k.get_shape().assert_has_rank(0) k = math_ops.mod(k, 4) @@ -373,8 +390,7 @@ def transpose_image(image): """ with ops.name_scope(None, 'transpose_image', [image]) as scope: image = ops.convert_to_tensor(image, name='image') - image = control_flow_ops.with_dependencies( - _Check3DImage(image, require_static=False), image) + image = _Assert3DImage(image) return array_ops.transpose(image, [1, 0, 2], name=scope) @@ -410,8 +426,7 @@ def central_crop(image, central_fraction): if central_fraction == 1.0: return image - image = control_flow_ops.with_dependencies( - _Check3DImage(image, require_static=False), image) + image = _Assert3DImage(image) img_shape = array_ops.shape(image) depth = image.get_shape()[2] @@ -848,8 +863,7 @@ def per_image_standardization(image): """ with ops.name_scope(None, 'per_image_standardization', [image]) as scope: image = ops.convert_to_tensor(image, name='image') - image = control_flow_ops.with_dependencies( - _Check3DImage(image, require_static=False), image) + image = _Assert3DImage(image) num_pixels = math_ops.reduce_prod(array_ops.shape(image)) image = math_ops.cast(image, dtype=dtypes.float32) |