aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-05-24 12:34:02 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-24 12:36:56 -0700
commit4054ddab84775659c4e04b4b239e3ef93e36a2de (patch)
treeba0673f4070b37f37e75d0cf6e3388f68fdb28e7
parent43408f89b46bbbbf76df90eb30f47ecc71af0876 (diff)
Modify tf.image.central_crop to support batched-input.
Currently central_crop works on singular images with dynamic dimensions. For large image classification models, it would be nice if central_crop can be modified to support batched input. This CL makes that change. PiperOrigin-RevId: 197935606
-rw-r--r--tensorflow/python/ops/image_ops_impl.py87
-rw-r--r--tensorflow/python/ops/image_ops_test.py132
2 files changed, 164 insertions, 55 deletions
diff --git a/tensorflow/python/ops/image_ops_impl.py b/tensorflow/python/ops/image_ops_impl.py
index 54e27b87df..52141ba24a 100644
--- a/tensorflow/python/ops/image_ops_impl.py
+++ b/tensorflow/python/ops/image_ops_impl.py
@@ -523,7 +523,7 @@ def transpose_image(image):
@tf_export('image.central_crop')
def central_crop(image, central_fraction):
- """Crop the central region of the image.
+ """Crop the central region of the image(s).
Remove the outer parts of an image but retain the central region of the image
along each dimension. If we specify central_fraction = 0.5, this function
@@ -536,15 +536,19 @@ def central_crop(image, central_fraction):
| | where "X" is the central 50% of the image.
--------
+ This function works on either a single image (`image` is a 3-D Tensor), or a
+ batch of images (`image` is a 4-D Tensor).
+
Args:
- image: 3-D float Tensor of shape [height, width, depth]
+ image: Either a 3-D float Tensor of shape [height, width, depth], or a 4-D
+ Tensor of shape [batch_size, height, width, depth].
central_fraction: float (0, 1], fraction of size to crop
Raises:
ValueError: if central_crop_fraction is not within (0, 1].
Returns:
- 3-D float Tensor
+ 3-D / 4-D float Tensor, as per the input.
"""
with ops.name_scope(None, 'central_crop', [image]):
image = ops.convert_to_tensor(image, name='image')
@@ -553,24 +557,75 @@ def central_crop(image, central_fraction):
if central_fraction == 1.0:
return image
- image = _Assert3DImage(image)
+ _AssertAtLeast3DImage(image)
+ rank = image.get_shape().ndims
+ if rank != 3 and rank != 4:
+ raise ValueError('`image` should either be a Tensor with rank = 3 or '
+ 'rank = 4. Had rank = {}.'.format(rank))
+
+ # Helper method to return the `idx`-th dimension of `tensor`, along with
+ # a boolean signifying if the dimension is dynamic.
+ def _get_dim(tensor, idx):
+ static_shape = tensor.get_shape()[idx].value
+ if static_shape is not None:
+ return static_shape, False
+ return array_ops.shape(tensor)[idx], True
+
+ # Get the height, width, depth (and batch size, if the image is a 4-D
+ # tensor).
+ if rank == 3:
+ img_h, dynamic_h = _get_dim(image, 0)
+ img_w, dynamic_w = _get_dim(image, 1)
+ img_d = image.get_shape()[2]
+ else:
+ img_bs = image.get_shape()[0]
+ img_h, dynamic_h = _get_dim(image, 1)
+ img_w, dynamic_w = _get_dim(image, 2)
+ img_d = image.get_shape()[3]
+
+ # Compute the bounding boxes for the crop. The type and value of the
+ # bounding boxes depend on the `image` tensor's rank and whether / not the
+ # dimensions are statically defined.
+ if dynamic_h:
+ img_hd = math_ops.to_double(img_h)
+ bbox_h_start = math_ops.to_int32((img_hd - img_hd * central_fraction) / 2)
+ else:
+ img_hd = float(img_h)
+ bbox_h_start = int((img_hd - img_hd * central_fraction) / 2)
- img_shape = array_ops.shape(image)
- depth = image.get_shape()[2]
- img_h = math_ops.to_double(img_shape[0])
- img_w = math_ops.to_double(img_shape[1])
- bbox_h_start = math_ops.to_int32((img_h - img_h * central_fraction) / 2)
- bbox_w_start = math_ops.to_int32((img_w - img_w * central_fraction) / 2)
+ if dynamic_w:
+ img_wd = math_ops.to_double(img_w)
+ bbox_w_start = math_ops.to_int32((img_wd - img_wd * central_fraction) / 2)
+ else:
+ img_wd = float(img_w)
+ bbox_w_start = int((img_wd - img_wd * central_fraction) / 2)
+
+ bbox_h_size = img_h - bbox_h_start * 2
+ bbox_w_size = img_w - bbox_w_start * 2
- bbox_h_size = img_shape[0] - bbox_h_start * 2
- bbox_w_size = img_shape[1] - bbox_w_start * 2
+ if rank == 3:
+ bbox_begin = array_ops.stack([bbox_h_start, bbox_w_start, 0])
+ bbox_size = array_ops.stack([bbox_h_size, bbox_w_size, -1])
+ else:
+ bbox_begin = array_ops.stack([0, bbox_h_start, bbox_w_start, 0])
+ bbox_size = array_ops.stack([-1, bbox_h_size, bbox_w_size, -1])
- bbox_begin = array_ops.stack([bbox_h_start, bbox_w_start, 0])
- bbox_size = array_ops.stack([bbox_h_size, bbox_w_size, -1])
image = array_ops.slice(image, bbox_begin, bbox_size)
- # The first two dimensions are dynamic and unknown.
- image.set_shape([None, None, depth])
+ # Reshape the `image` tensor to the desired size.
+ if rank == 3:
+ image.set_shape([
+ None if dynamic_h else bbox_h_size,
+ None if dynamic_w else bbox_w_size,
+ img_d
+ ])
+ else:
+ image.set_shape([
+ img_bs,
+ None if dynamic_h else bbox_h_size,
+ None if dynamic_w else bbox_w_size,
+ img_d
+ ])
return image
diff --git a/tensorflow/python/ops/image_ops_test.py b/tensorflow/python/ops/image_ops_test.py
index c437c12c27..72c889a2e6 100644
--- a/tensorflow/python/ops/image_ops_test.py
+++ b/tensorflow/python/ops/image_ops_test.py
@@ -1585,14 +1585,16 @@ class CentralCropTest(test_util.TensorFlowTestCase):
self.assertEqual(y.get_shape().as_list(), post_shape)
def testNoOp(self):
- x_shape = [13, 9, 3]
- x_np = np.ones(x_shape, dtype=np.float32)
- with self.test_session(use_gpu=True):
- x = constant_op.constant(x_np, shape=x_shape)
- y = image_ops.central_crop(x, 1.0)
- y_tf = y.eval()
- self.assertAllEqual(y_tf, x_np)
- self.assertEqual(y.op.name, x.op.name)
+ x_shapes = [[13, 9, 3], [5, 13, 9, 3]]
+ for x_shape in x_shapes:
+ x_np = np.ones(x_shape, dtype=np.float32)
+ for use_gpu in [True, False]:
+ with self.test_session(use_gpu=use_gpu):
+ x = constant_op.constant(x_np, shape=x_shape)
+ y = image_ops.central_crop(x, 1.0)
+ y_tf = y.eval()
+ self.assertAllEqual(y_tf, x_np)
+ self.assertEqual(y.op.name, x.op.name)
def testCropping(self):
x_shape = [4, 8, 1]
@@ -1601,6 +1603,23 @@ class CentralCropTest(test_util.TensorFlowTestCase):
[1, 2, 3, 4, 5, 6, 7, 8], [1, 2, 3, 4, 5, 6, 7, 8]],
dtype=np.int32).reshape(x_shape)
y_np = np.array([[3, 4, 5, 6], [3, 4, 5, 6]]).reshape([2, 4, 1])
+ for use_gpu in [True, False]:
+ with self.test_session(use_gpu=use_gpu):
+ x = constant_op.constant(x_np, shape=x_shape)
+ y = image_ops.central_crop(x, 0.5)
+ y_tf = y.eval()
+ self.assertAllEqual(y_tf, y_np)
+ self.assertAllEqual(y_tf.shape, y_np.shape)
+
+ x_shape = [2, 4, 8, 1]
+ x_np = np.array(
+ [[1, 2, 3, 4, 5, 6, 7, 8], [1, 2, 3, 4, 5, 6, 7, 8],
+ [1, 2, 3, 4, 5, 6, 7, 8], [1, 2, 3, 4, 5, 6, 7, 8],
+ [8, 7, 6, 5, 4, 3, 2, 1], [8, 7, 6, 5, 4, 3, 2, 1],
+ [8, 7, 6, 5, 4, 3, 2, 1], [8, 7, 6, 5, 4, 3, 2, 1]],
+ dtype=np.int32).reshape(x_shape)
+ y_np = np.array([[[3, 4, 5, 6], [3, 4, 5, 6]],
+ [[6, 5, 4, 3], [6, 5, 4, 3]]]).reshape([2, 2, 4, 1])
with self.test_session(use_gpu=True):
x = constant_op.constant(x_np, shape=x_shape)
y = image_ops.central_crop(x, 0.5)
@@ -1610,52 +1629,87 @@ class CentralCropTest(test_util.TensorFlowTestCase):
def testCropping2(self):
# Test case for 10315
- x_shape = [240, 320, 3]
- x_np = np.zeros(x_shape, dtype=np.int32)
- y_np = np.zeros([80, 106, 3], dtype=np.int32)
- with self.test_session(use_gpu=True):
- x = array_ops.placeholder(shape=x_shape, dtype=dtypes.int32)
- y = image_ops.central_crop(x, 0.33)
- y_tf = y.eval(feed_dict={x: x_np})
- self.assertAllEqual(y_tf, y_np)
- self.assertAllEqual(y_tf.shape, y_np.shape)
+ x_shapes = [[240, 320, 3], [5, 240, 320, 3]]
+ expected_y_shapes = [[80, 106, 3], [5, 80, 106, 3]]
+
+ for x_shape, y_shape in zip(x_shapes, expected_y_shapes):
+ x_np = np.zeros(x_shape, dtype=np.int32)
+ y_np = np.zeros(y_shape, dtype=np.int32)
+ for use_gpu in [True, False]:
+ with self.test_session(use_gpu=use_gpu):
+ x = array_ops.placeholder(shape=x_shape, dtype=dtypes.int32)
+ y = image_ops.central_crop(x, 0.33)
+ y_tf = y.eval(feed_dict={x: x_np})
+ self.assertAllEqual(y_tf, y_np)
+ self.assertAllEqual(y_tf.shape, y_np.shape)
def testShapeInference(self):
- # Test no-op fraction=1.0
+ # Test no-op fraction=1.0, with 3-D tensors.
self._assertShapeInference([50, 60, 3], 1.0, [50, 60, 3])
self._assertShapeInference([None, 60, 3], 1.0, [None, 60, 3])
self._assertShapeInference([50, None, 3], 1.0, [50, None, 3])
self._assertShapeInference([None, None, 3], 1.0, [None, None, 3])
self._assertShapeInference([50, 60, None], 1.0, [50, 60, None])
self._assertShapeInference([None, None, None], 1.0, [None, None, None])
- self._assertShapeInference(None, 1.0, None)
- # TODO(toddw): Currently central_crop() doesn't infer the result shape even
- # when it's possible. If we change it to do so, we can test as follows:
- #
- # self._assertShapeInference([50, 60, 3], 0.5, [25, 30, 3])
- # self._assertShapeInference([None, 60, 3], 0.5, [None, 30, 3])
- # self._assertShapeInference([50, None, 3], 0.5, [25, None, 3])
- # self._assertShapeInference([None, None, 3], 0.5, [None, None, 3])
- # self._assertShapeInference([50, 60, None], 0.5, [25, 30, None])
- # self._assertShapeInference([None, None, None], 0.5, [None, None, None])
- # self._assertShapeInference(None, 0.5, None)
- def testError(self):
+ # Test no-op fraction=0.5, with 3-D tensors.
+ self._assertShapeInference([50, 60, 3], 0.5, [26, 30, 3])
+ self._assertShapeInference([None, 60, 3], 0.5, [None, 30, 3])
+ self._assertShapeInference([50, None, 3], 0.5, [26, None, 3])
+ self._assertShapeInference([None, None, 3], 0.5, [None, None, 3])
+ self._assertShapeInference([50, 60, None], 0.5, [26, 30, None])
+ self._assertShapeInference([None, None, None], 0.5, [None, None, None])
+
+ # Test no-op fraction=1.0, with 4-D tensors.
+ self._assertShapeInference([5, 50, 60, 3], 1.0, [5, 50, 60, 3])
+ self._assertShapeInference([5, None, 60, 3], 1.0, [5, None, 60, 3])
+ self._assertShapeInference([5, 50, None, 3], 1.0, [5, 50, None, 3])
+ self._assertShapeInference([5, None, None, 3], 1.0, [5, None, None, 3])
+ self._assertShapeInference([5, 50, 60, None], 1.0, [5, 50, 60, None])
+ self._assertShapeInference([5, None, None, None], 1.0,
+ [5, None, None, None])
+ self._assertShapeInference([None, None, None, None], 1.0,
+ [None, None, None, None])
+
+ # Test no-op fraction=0.5, with 4-D tensors.
+ self._assertShapeInference([5, 50, 60, 3], 0.5, [5, 26, 30, 3])
+ self._assertShapeInference([5, None, 60, 3], 0.5, [5, None, 30, 3])
+ self._assertShapeInference([5, 50, None, 3], 0.5, [5, 26, None, 3])
+ self._assertShapeInference([5, None, None, 3], 0.5, [5, None, None, 3])
+ self._assertShapeInference([5, 50, 60, None], 0.5, [5, 26, 30, None])
+ self._assertShapeInference([5, None, None, None], 0.5,
+ [5, None, None, None])
+ self._assertShapeInference([None, None, None, None], 0.5,
+ [None, None, None, None])
+
+ def testErrorOnInvalidCentralCropFractionValues(self):
x_shape = [13, 9, 3]
x_np = np.ones(x_shape, dtype=np.float32)
- with self.test_session(use_gpu=True):
- x = constant_op.constant(x_np, shape=x_shape)
- with self.assertRaises(ValueError):
- _ = image_ops.central_crop(x, 0.0)
- with self.assertRaises(ValueError):
- _ = image_ops.central_crop(x, 1.01)
+ for use_gpu in [True, False]:
+ with self.test_session(use_gpu=use_gpu):
+ x = constant_op.constant(x_np, shape=x_shape)
+ with self.assertRaises(ValueError):
+ _ = image_ops.central_crop(x, 0.0)
+ with self.assertRaises(ValueError):
+ _ = image_ops.central_crop(x, 1.01)
+
+ def testErrorOnInvalidShapes(self):
+ x_shapes = [None, [], [3], [3, 9], [3, 9, 3, 9, 3]]
+ for x_shape in x_shapes:
+ x_np = np.ones(x_shape, dtype=np.float32)
+ for use_gpu in [True, False]:
+ with self.test_session(use_gpu=use_gpu):
+ x = constant_op.constant(x_np, shape=x_shape)
+ with self.assertRaises(ValueError):
+ _ = image_ops.central_crop(x, 0.5)
def testNameScope(self):
x_shape = [13, 9, 3]
x_np = np.ones(x_shape, dtype=np.float32)
- with self.test_session(use_gpu=True):
- y = image_ops.central_crop(x_np, 1.0)
- self.assertTrue(y.op.name.startswith("central_crop"))
+ for use_gpu in [True, False]:
+ with self.test_session(use_gpu=use_gpu):
+ y = image_ops.central_crop(x_np, 1.0)
+ self.assertTrue(y.op.name.startswith("central_crop"))
class PadToBoundingBoxTest(test_util.TensorFlowTestCase):