aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/image_ops_impl.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/image_ops_impl.py')
-rw-r--r--tensorflow/python/ops/image_ops_impl.py46
1 files changed, 30 insertions, 16 deletions
diff --git a/tensorflow/python/ops/image_ops_impl.py b/tensorflow/python/ops/image_ops_impl.py
index d860a3b618..b713c44717 100644
--- a/tensorflow/python/ops/image_ops_impl.py
+++ b/tensorflow/python/ops/image_ops_impl.py
@@ -148,6 +148,28 @@ def _Check3DImage(image, require_static=True):
return []
+def _Assert3DImage(image):
+ """Assert that we are working with a properly shaped image.
+
+ Performs the check statically if possible (i.e. if the shape
+ is statically known). Otherwise adds a control dependency
+ to an assert op that checks the dynamic shape.
+
+ Args:
+ image: 3-D Tensor of shape [height, width, channels]
+
+ Raises:
+ ValueError: if `image.shape` is not a 3-vector.
+
+ Returns:
+ If the shape of `image` could be verified statically, `image` is
+ returned unchanged, otherwise there will be a control dependency
+ added that asserts the correct dynamic shape.
+ """
+ return control_flow_ops.with_dependencies(
+ _Check3DImage(image, require_static=False), image)
+
+
def _CheckAtLeast3DImage(image, require_static=True):
"""Assert that we are working with properly shaped image.
@@ -223,8 +245,7 @@ def random_flip_up_down(image, seed=None):
"""
with ops.name_scope(None, 'random_flip_up_down', [image]) as scope:
image = ops.convert_to_tensor(image, name='image')
- image = control_flow_ops.with_dependencies(
- _Check3DImage(image, require_static=False), image)
+ image = _Assert3DImage(image)
uniform_random = random_ops.random_uniform([], 0, 1.0, seed=seed)
mirror_cond = math_ops.less(uniform_random, .5)
result = control_flow_ops.cond(mirror_cond,
@@ -255,8 +276,7 @@ def random_flip_left_right(image, seed=None):
"""
with ops.name_scope(None, 'random_flip_left_right', [image]) as scope:
image = ops.convert_to_tensor(image, name='image')
- image = control_flow_ops.with_dependencies(
- _Check3DImage(image, require_static=False), image)
+ image = _Assert3DImage(image)
uniform_random = random_ops.random_uniform([], 0, 1.0, seed=seed)
mirror_cond = math_ops.less(uniform_random, .5)
result = control_flow_ops.cond(mirror_cond,
@@ -286,8 +306,7 @@ def flip_left_right(image):
"""
with ops.name_scope(None, 'flip_left_right', [image]) as scope:
image = ops.convert_to_tensor(image, name='image')
- image = control_flow_ops.with_dependencies(
- _Check3DImage(image, require_static=False), image)
+ image = _Assert3DImage(image)
return fix_image_flip_shape(image,
array_ops.reverse(image, [1], name=scope))
@@ -312,8 +331,7 @@ def flip_up_down(image):
"""
with ops.name_scope(None, 'flip_up_down', [image]) as scope:
image = ops.convert_to_tensor(image, name='image')
- image = control_flow_ops.with_dependencies(
- _Check3DImage(image, require_static=False), image)
+ image = _Assert3DImage(image)
return fix_image_flip_shape(image,
array_ops.reverse(image, [0], name=scope))
@@ -332,8 +350,7 @@ def rot90(image, k=1, name=None):
"""
with ops.name_scope(name, 'rot90', [image, k]) as scope:
image = ops.convert_to_tensor(image, name='image')
- image = control_flow_ops.with_dependencies(
- _Check3DImage(image, require_static=False), image)
+ image = _Assert3DImage(image)
k = ops.convert_to_tensor(k, dtype=dtypes.int32, name='k')
k.get_shape().assert_has_rank(0)
k = math_ops.mod(k, 4)
@@ -373,8 +390,7 @@ def transpose_image(image):
"""
with ops.name_scope(None, 'transpose_image', [image]) as scope:
image = ops.convert_to_tensor(image, name='image')
- image = control_flow_ops.with_dependencies(
- _Check3DImage(image, require_static=False), image)
+ image = _Assert3DImage(image)
return array_ops.transpose(image, [1, 0, 2], name=scope)
@@ -410,8 +426,7 @@ def central_crop(image, central_fraction):
if central_fraction == 1.0:
return image
- image = control_flow_ops.with_dependencies(
- _Check3DImage(image, require_static=False), image)
+ image = _Assert3DImage(image)
img_shape = array_ops.shape(image)
depth = image.get_shape()[2]
@@ -848,8 +863,7 @@ def per_image_standardization(image):
"""
with ops.name_scope(None, 'per_image_standardization', [image]) as scope:
image = ops.convert_to_tensor(image, name='image')
- image = control_flow_ops.with_dependencies(
- _Check3DImage(image, require_static=False), image)
+ image = _Assert3DImage(image)
num_pixels = math_ops.reduce_prod(array_ops.shape(image))
image = math_ops.cast(image, dtype=dtypes.float32)