aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/image_ops_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/image_ops_test.py')
-rw-r--r--tensorflow/python/ops/image_ops_test.py35
1 files changed, 35 insertions, 0 deletions
diff --git a/tensorflow/python/ops/image_ops_test.py b/tensorflow/python/ops/image_ops_test.py
index 0e4193e23b..2c61bb232a 100644
--- a/tensorflow/python/ops/image_ops_test.py
+++ b/tensorflow/python/ops/image_ops_test.py
@@ -3658,6 +3658,41 @@ class NonMaxSuppressionTest(test_util.TensorFlowTestCase):
image_ops.non_max_suppression(boxes, scores, 3, [[0.5]])
+class NonMaxSuppressionPaddedTest(test_util.TensorFlowTestCase):
+
+ def testSelectFromThreeClusters(self):
+ boxes_np = [[0, 0, 1, 1], [0, 0.1, 1, 1.1], [0, -0.1, 1, 0.9],
+ [0, 10, 1, 11], [0, 10.1, 1, 11.1], [0, 100, 1, 101]]
+ scores_np = [0.9, 0.75, 0.6, 0.95, 0.5, 0.3]
+ max_output_size_np = 5
+ iou_threshold_np = 0.5
+ boxes = constant_op.constant(boxes_np)
+ scores = constant_op.constant(scores_np)
+ max_output_size = constant_op.constant(max_output_size_np)
+ iou_threshold = constant_op.constant(iou_threshold_np)
+ selected_indices_padded, num_valid_padded = \
+ image_ops.non_max_suppression_padded(
+ boxes,
+ scores,
+ max_output_size,
+ iou_threshold,
+ pad_to_max_output_size=True)
+ selected_indices, num_valid = image_ops.non_max_suppression_padded(
+ boxes,
+ scores,
+ max_output_size,
+ iou_threshold,
+ pad_to_max_output_size=False)
+ # The output shape of the padded operation must be fully defined.
+ self.assertEqual(selected_indices_padded.shape.is_fully_defined(), True)
+ self.assertEqual(selected_indices.shape.is_fully_defined(), False)
+ with self.test_session():
+ self.assertAllClose(selected_indices_padded.eval(), [3, 0, 5, 0, 0])
+ self.assertEqual(num_valid_padded.eval(), 3)
+ self.assertAllClose(selected_indices.eval(), [3, 0, 5])
+ self.assertEqual(num_valid.eval(), 3)
+
+
class VerifyCompatibleImageShapesTest(test_util.TensorFlowTestCase):
"""Tests utility function used by ssim() and psnr()."""