aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/image_ops_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/image_ops_test.py')
-rw-r--r--tensorflow/python/ops/image_ops_test.py165
1 files changed, 149 insertions, 16 deletions
diff --git a/tensorflow/python/ops/image_ops_test.py b/tensorflow/python/ops/image_ops_test.py
index 18625293e0..b67e7cc558 100644
--- a/tensorflow/python/ops/image_ops_test.py
+++ b/tensorflow/python/ops/image_ops_test.py
@@ -934,7 +934,7 @@ class AdjustSaturationTest(test_util.TensorFlowTestCase):
class FlipTransposeRotateTest(test_util.TensorFlowTestCase):
- def testIdempotentLeftRight(self):
+ def testInvolutionLeftRight(self):
x_np = np.array([[1, 2, 3], [1, 2, 3]], dtype=np.uint8).reshape([2, 3, 1])
with self.test_session(use_gpu=True):
x_tf = constant_op.constant(x_np, shape=x_np.shape)
@@ -942,6 +942,16 @@ class FlipTransposeRotateTest(test_util.TensorFlowTestCase):
y_tf = y.eval()
self.assertAllEqual(y_tf, x_np)
+ def testInvolutionLeftRightWithBatch(self):
+ x_np = np.array(
+ [[[1, 2, 3], [1, 2, 3]], [[1, 2, 3], [1, 2, 3]]],
+ dtype=np.uint8).reshape([2, 2, 3, 1])
+ with self.test_session(use_gpu=True):
+ x_tf = constant_op.constant(x_np, shape=x_np.shape)
+ y = image_ops.flip_left_right(image_ops.flip_left_right(x_tf))
+ y_tf = y.eval()
+ self.assertAllEqual(y_tf, x_np)
+
def testLeftRight(self):
x_np = np.array([[1, 2, 3], [1, 2, 3]], dtype=np.uint8).reshape([2, 3, 1])
y_np = np.array([[3, 2, 1], [3, 2, 1]], dtype=np.uint8).reshape([2, 3, 1])
@@ -953,9 +963,24 @@ class FlipTransposeRotateTest(test_util.TensorFlowTestCase):
y_tf = y.eval()
self.assertAllEqual(y_tf, y_np)
+ def testLeftRightWithBatch(self):
+ x_np = np.array(
+ [[[1, 2, 3], [1, 2, 3]], [[1, 2, 3], [1, 2, 3]]],
+ dtype=np.uint8).reshape([2, 2, 3, 1])
+ y_np = np.array(
+ [[[3, 2, 1], [3, 2, 1]], [[3, 2, 1], [3, 2, 1]]],
+ dtype=np.uint8).reshape([2, 2, 3, 1])
+
+ with self.test_session(use_gpu=True):
+ x_tf = constant_op.constant(x_np, shape=x_np.shape)
+ y = image_ops.flip_left_right(x_tf)
+ y_tf = y.eval()
+ self.assertAllEqual(y_tf, y_np)
+
def testRandomFlipLeftRight(self):
x_np = np.array([[1, 2, 3], [1, 2, 3]], dtype=np.uint8).reshape([2, 3, 1])
y_np = np.array([[3, 2, 1], [3, 2, 1]], dtype=np.uint8).reshape([2, 3, 1])
+ seed = 42
with self.test_session(use_gpu=True):
x_tf = constant_op.constant(x_np, shape=x_np.shape)
@@ -964,7 +989,7 @@ class FlipTransposeRotateTest(test_util.TensorFlowTestCase):
count_flipped = 0
count_unflipped = 0
- for _ in range(50):
+ for _ in range(100):
y_tf = y.eval()
if y_tf[0][0] == 1:
self.assertAllEqual(y_tf, x_np)
@@ -972,10 +997,15 @@ class FlipTransposeRotateTest(test_util.TensorFlowTestCase):
else:
self.assertAllEqual(y_tf, y_np)
count_flipped += 1
- self.assertGreaterEqual(count_flipped, 1)
- self.assertGreaterEqual(count_unflipped, 1)
- def testIdempotentUpDown(self):
+ # 100 trials
+ # Mean: 50
+ # Std Dev: ~5
+ # Six Sigma: 50 - (5 * 6) = 20
+ self.assertGreaterEqual(count_flipped, 20)
+ self.assertGreaterEqual(count_unflipped, 20)
+
+ def testInvolutionUpDown(self):
x_np = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint8).reshape([2, 3, 1])
with self.test_session(use_gpu=True):
@@ -984,6 +1014,17 @@ class FlipTransposeRotateTest(test_util.TensorFlowTestCase):
y_tf = y.eval()
self.assertAllEqual(y_tf, x_np)
+ def testInvolutionUpDownWithBatch(self):
+ x_np = np.array(
+ [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]],
+ dtype=np.uint8).reshape([2, 2, 3, 1])
+
+ with self.test_session(use_gpu=True):
+ x_tf = constant_op.constant(x_np, shape=x_np.shape)
+ y = image_ops.flip_up_down(image_ops.flip_up_down(x_tf))
+ y_tf = y.eval()
+ self.assertAllEqual(y_tf, x_np)
+
def testUpDown(self):
x_np = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint8).reshape([2, 3, 1])
y_np = np.array([[4, 5, 6], [1, 2, 3]], dtype=np.uint8).reshape([2, 3, 1])
@@ -995,17 +1036,31 @@ class FlipTransposeRotateTest(test_util.TensorFlowTestCase):
y_tf = y.eval()
self.assertAllEqual(y_tf, y_np)
+ def testUpDownWithBatch(self):
+ x_np = np.array(
+ [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]],
+ dtype=np.uint8).reshape([2, 2, 3, 1])
+ y_np = np.array(
+ [[[4, 5, 6], [1, 2, 3]], [[10, 11, 12], [7, 8, 9]]],
+ dtype=np.uint8).reshape([2, 2, 3, 1])
+
+ with self.test_session(use_gpu=True):
+ x_tf = constant_op.constant(x_np, shape=x_np.shape)
+ y = image_ops.flip_up_down(x_tf)
+ y_tf = y.eval()
+ self.assertAllEqual(y_tf, y_np)
+
def testRandomFlipUpDown(self):
x_np = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint8).reshape([2, 3, 1])
y_np = np.array([[4, 5, 6], [1, 2, 3]], dtype=np.uint8).reshape([2, 3, 1])
with self.test_session(use_gpu=True):
x_tf = constant_op.constant(x_np, shape=x_np.shape)
- y = image_ops.random_flip_up_down(x_tf)
+ y = image_ops.random_flip_up_down(x_tf, seed=42)
self.assertTrue(y.op.name.startswith("random_flip_up_down"))
count_flipped = 0
count_unflipped = 0
- for _ in range(50):
+ for _ in range(100):
y_tf = y.eval()
if y_tf[0][0] == 1:
self.assertAllEqual(y_tf, x_np)
@@ -1013,10 +1068,15 @@ class FlipTransposeRotateTest(test_util.TensorFlowTestCase):
else:
self.assertAllEqual(y_tf, y_np)
count_flipped += 1
- self.assertGreaterEqual(count_flipped, 1)
- self.assertGreaterEqual(count_unflipped, 1)
- def testIdempotentTranspose(self):
+ # 100 trials
+ # Mean: 50
+ # Std Dev: ~5
+ # Six Sigma: 50 - (5 * 6) = 20
+ self.assertGreaterEqual(count_flipped, 20)
+ self.assertGreaterEqual(count_unflipped, 20)
+
+ def testInvolutionTranspose(self):
x_np = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint8).reshape([2, 3, 1])
with self.test_session(use_gpu=True):
@@ -1025,6 +1085,17 @@ class FlipTransposeRotateTest(test_util.TensorFlowTestCase):
y_tf = y.eval()
self.assertAllEqual(y_tf, x_np)
+ def testInvolutionTransposeWithBatch(self):
+ x_np = np.array(
+ [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]],
+ dtype=np.uint8).reshape([2, 2, 3, 1])
+
+ with self.test_session(use_gpu=True):
+ x_tf = constant_op.constant(x_np, shape=x_np.shape)
+ y = image_ops.transpose_image(image_ops.transpose_image(x_tf))
+ y_tf = y.eval()
+ self.assertAllEqual(y_tf, x_np)
+
def testTranspose(self):
x_np = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint8).reshape([2, 3, 1])
y_np = np.array([[1, 4], [2, 5], [3, 6]], dtype=np.uint8).reshape([3, 2, 1])
@@ -1036,15 +1107,34 @@ class FlipTransposeRotateTest(test_util.TensorFlowTestCase):
y_tf = y.eval()
self.assertAllEqual(y_tf, y_np)
+ def testTransposeWithBatch(self):
+ x_np = np.array(
+ [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]],
+ dtype=np.uint8).reshape([2, 2, 3, 1])
+
+ y_np = np.array(
+ [[[1, 4], [2, 5], [3, 6]], [[7, 10], [8, 11], [9, 12]]],
+ dtype=np.uint8).reshape([2, 3, 2, 1])
+
+ with self.test_session(use_gpu=True):
+ x_tf = constant_op.constant(x_np, shape=x_np.shape)
+ y = image_ops.transpose_image(x_tf)
+ y_tf = y.eval()
+ self.assertAllEqual(y_tf, y_np)
+
def testPartialShapes(self):
p_unknown_rank = array_ops.placeholder(dtypes.uint8)
- p_unknown_dims = array_ops.placeholder(
+ p_unknown_dims_3 = array_ops.placeholder(
dtypes.uint8, shape=[None, None, None])
+ p_unknown_dims_4 = array_ops.placeholder(
+ dtypes.uint8, shape=[None, None, None, None])
p_unknown_width = array_ops.placeholder(dtypes.uint8, shape=[64, None, 3])
-
+ p_unknown_batch = array_ops.placeholder(
+ dtypes.uint8, shape=[None, 64, 64, 3])
p_wrong_rank = array_ops.placeholder(dtypes.uint8, shape=[None, None])
p_zero_dim = array_ops.placeholder(dtypes.uint8, shape=[64, 0, 3])
+ #Ops that support 3D input
for op in [
image_ops.flip_left_right, image_ops.flip_up_down,
image_ops.random_flip_left_right, image_ops.random_flip_up_down,
@@ -1052,16 +1142,34 @@ class FlipTransposeRotateTest(test_util.TensorFlowTestCase):
]:
transformed_unknown_rank = op(p_unknown_rank)
self.assertEqual(3, transformed_unknown_rank.get_shape().ndims)
- transformed_unknown_dims = op(p_unknown_dims)
- self.assertEqual(3, transformed_unknown_dims.get_shape().ndims)
+ transformed_unknown_dims_3 = op(p_unknown_dims_3)
+ self.assertEqual(3, transformed_unknown_dims_3.get_shape().ndims)
transformed_unknown_width = op(p_unknown_width)
self.assertEqual(3, transformed_unknown_width.get_shape().ndims)
- with self.assertRaisesRegexp(ValueError, "must be three-dimensional"):
- op(p_wrong_rank)
with self.assertRaisesRegexp(ValueError, "must be > 0"):
op(p_zero_dim)
+ #Ops that support 4D input
+ for op in [
+ image_ops.flip_left_right, image_ops.flip_up_down,
+ image_ops.transpose_image, image_ops.rot90
+ ]:
+ transformed_unknown_dims_4 = op(p_unknown_dims_4)
+ self.assertEqual(4, transformed_unknown_dims_4.get_shape().ndims)
+ transformed_unknown_batch = op(p_unknown_batch)
+ self.assertEqual(4, transformed_unknown_batch.get_shape().ndims)
+ with self.assertRaisesRegexp(ValueError,
+ "must be at least three-dimensional"):
+ op(p_wrong_rank)
+
+ for op in [
+ image_ops.random_flip_left_right,
+ image_ops.random_flip_up_down,
+ ]:
+ with self.assertRaisesRegexp(ValueError, "must be three-dimensional"):
+ op(p_wrong_rank)
+
def testRot90GroupOrder(self):
image = np.arange(24, dtype=np.uint8).reshape([2, 4, 3])
with self.test_session(use_gpu=True):
@@ -1070,6 +1178,14 @@ class FlipTransposeRotateTest(test_util.TensorFlowTestCase):
rotated = image_ops.rot90(rotated)
self.assertAllEqual(image, rotated.eval())
+ def testRot90GroupOrderWithBatch(self):
+ image = np.arange(48, dtype=np.uint8).reshape([2, 2, 4, 3])
+ with self.test_session(use_gpu=True):
+ rotated = image
+ for _ in xrange(4):
+ rotated = image_ops.rot90(rotated)
+ self.assertAllEqual(image, rotated.eval())
+
def testRot90NumpyEquivalence(self):
image = np.arange(24, dtype=np.uint8).reshape([2, 4, 3])
with self.test_session(use_gpu=True):
@@ -1079,6 +1195,15 @@ class FlipTransposeRotateTest(test_util.TensorFlowTestCase):
y_np = np.rot90(image, k=k)
self.assertAllEqual(y_np, y_tf.eval({k_placeholder: k}))
+ def testRot90NumpyEquivalenceWithBatch(self):
+ image = np.arange(48, dtype=np.uint8).reshape([2, 2, 4, 3])
+ with self.test_session(use_gpu=True):
+ k_placeholder = array_ops.placeholder(dtypes.int32, shape=[])
+ y_tf = image_ops.rot90(image, k_placeholder)
+ for k in xrange(4):
+ y_np = np.rot90(image, k=k, axes=(1, 2))
+ self.assertAllEqual(y_np, y_tf.eval({k_placeholder: k}))
+
class RandomFlipTest(test_util.TensorFlowTestCase):
@@ -3173,6 +3298,14 @@ class NonMaxSuppressionTest(test_util.TensorFlowTestCase):
scores = constant_op.constant([0.9])
image_ops.non_max_suppression(boxes, scores, 3, 0.5)
+ # The boxes is of shape [num_boxes, 4], and the scores is
+ # of shape [num_boxes]. So an error will thrown.
+ with self.assertRaisesRegexp(ValueError,
+ "Dimensions must be equal, but are 1 and 2"):
+ boxes = constant_op.constant([[0.0, 0.0, 1.0, 1.0]])
+ scores = constant_op.constant([0.9, 0.75])
+ selected_indices = image_ops.non_max_suppression(boxes, scores, 3, 0.5)
+
# The scores should be 1D of shape [num_boxes].
with self.assertRaisesRegexp(ValueError,
"Shape must be rank 1 but is rank 2"):