aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tests/image_ops_test.py
diff options
context:
space:
mode:
authorGravatar Tayo Oguntebi <tayo@google.com>2018-07-30 17:13:09 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-30 17:16:55 -0700
commit8566d9e6fa7dbe3660339befe8b0a3344d24ef2b (patch)
tree54143734f9b19a9ead1a23784f75dcd2009d1452 /tensorflow/compiler/tests/image_ops_test.py
parent8fee2e4b7c915d952332dc8cc9be7cfefea35162 (diff)
Adds a NonMaxSuppressionV4 op, with a corresponding TF2XLA implementation.
PiperOrigin-RevId: 206673787
Diffstat (limited to 'tensorflow/compiler/tests/image_ops_test.py')
-rw-r--r--tensorflow/compiler/tests/image_ops_test.py136
1 files changed, 136 insertions, 0 deletions
diff --git a/tensorflow/compiler/tests/image_ops_test.py b/tensorflow/compiler/tests/image_ops_test.py
index 8b01ef96db..bf986ade06 100644
--- a/tensorflow/compiler/tests/image_ops_test.py
+++ b/tensorflow/compiler/tests/image_ops_test.py
@@ -26,6 +26,7 @@ import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.compiler.tests import xla_test
+from tensorflow.python.compat import compat
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
@@ -579,5 +580,140 @@ class ResizeBilinearTest(xla_test.XLATestCase):
large_tolerance=True)
+class NonMaxSuppressionTest(xla_test.XLATestCase):
+
+ def testNMS128From1024(self):
+ # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU.
+ if self.device in ["XLA_CPU", "XLA_GPU"]:
+ return
+
+ with compat.forward_compatibility_horizon(2018, 8, 8):
+ num_boxes = 1024
+ boxes_np = np.random.normal(50, 10, (num_boxes, 4)).astype("f4")
+ scores_np = np.random.normal(0.5, 0.1, (num_boxes,)).astype("f4")
+
+ max_output_size = 128
+ iou_threshold_np = np.array(0.5, dtype=np.float32)
+ score_threshold_np = np.array(0.0, dtype=np.float32)
+
+ with self.test_session() as sess:
+ boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape)
+ scores = array_ops.placeholder(scores_np.dtype, shape=scores_np.shape)
+ iou_threshold = array_ops.placeholder(iou_threshold_np.dtype,
+ iou_threshold_np.shape)
+ score_threshold = array_ops.placeholder(score_threshold_np.dtype,
+ score_threshold_np.shape)
+ with self.test_scope():
+ selected_indices = image_ops.non_max_suppression_padded(
+ boxes=boxes,
+ scores=scores,
+ max_output_size=max_output_size,
+ iou_threshold=iou_threshold,
+ score_threshold=score_threshold,
+ pad_to_max_output_size=True)
+ inputs_feed = {
+ boxes: boxes_np,
+ scores: scores_np,
+ score_threshold: score_threshold_np,
+ iou_threshold: iou_threshold_np
+ }
+ (indices_tf, _) = sess.run(selected_indices, feed_dict=inputs_feed)
+
+ self.assertEqual(indices_tf.size, max_output_size)
+
+ def testNMS3From6Boxes(self):
+ # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU.
+ if self.device in ["XLA_CPU", "XLA_GPU"]:
+ return
+
+ with compat.forward_compatibility_horizon(2018, 8, 8):
+ # Three boxes are selected based on IOU.
+ boxes_data = [[0, 0, 1, 1], [0, 0.1, 1, 1.1], [0, -0.1, 1, 0.9],
+ [0, 10, 1, 11], [0, 10.1, 1, 11.1], [0, 100, 1, 101]]
+ boxes_np = np.array(boxes_data, dtype=np.float32)
+
+ scores_data = [0.9, 0.75, 0.6, 0.95, 0.5, 0.3]
+ scores_np = np.array(scores_data, dtype=np.float32)
+
+ max_output_size = 3
+ iou_threshold_np = np.array(0.5, dtype=np.float32)
+ score_threshold_np = np.array(0.0, dtype=np.float32)
+
+ with self.test_session() as sess:
+ boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape)
+ scores = array_ops.placeholder(scores_np.dtype, shape=scores_np.shape)
+ iou_threshold = array_ops.placeholder(iou_threshold_np.dtype,
+ iou_threshold_np.shape)
+ score_threshold = array_ops.placeholder(score_threshold_np.dtype,
+ score_threshold_np.shape)
+ with self.test_scope():
+ selected_indices = image_ops.non_max_suppression_padded(
+ boxes=boxes,
+ scores=scores,
+ max_output_size=max_output_size,
+ iou_threshold=iou_threshold,
+ score_threshold=score_threshold,
+ pad_to_max_output_size=True)
+ inputs_feed = {
+ boxes: boxes_np,
+ scores: scores_np,
+ score_threshold: score_threshold_np,
+ iou_threshold: iou_threshold_np
+ }
+ (indices_tf, num_valid) = sess.run(
+ selected_indices, feed_dict=inputs_feed)
+
+ self.assertEqual(indices_tf.size, max_output_size)
+ self.assertEqual(num_valid, 3)
+ self.assertAllClose(indices_tf[:num_valid], [3, 0, 5])
+
+ def testNMS3Then2WithScoreThresh(self):
+ # Three boxes are selected based on IOU.
+ # One is filtered out by score threshold.
+
+ # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU.
+ if self.device in ["XLA_CPU", "XLA_GPU"]:
+ return
+
+ with compat.forward_compatibility_horizon(2018, 8, 8):
+ boxes_data = [[0, 0, 1, 1], [0, 0.1, 1, 1.1], [0, -0.1, 1, 0.9],
+ [0, 10, 1, 11], [0, 10.1, 1, 11.1], [0, 100, 1, 101]]
+ boxes_np = np.array(boxes_data, dtype=np.float32)
+
+ scores_data = [0.9, 0.75, 0.6, 0.95, 0.5, 0.3]
+ scores_np = np.array(scores_data, dtype=np.float32)
+ max_output_size = 3
+ iou_threshold_np = np.array(0.5, dtype=np.float32)
+ score_threshold_np = np.array(0.4, dtype=np.float32)
+
+ with self.test_session() as sess:
+ boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape)
+ scores = array_ops.placeholder(scores_np.dtype, shape=scores_np.shape)
+ iou_threshold = array_ops.placeholder(iou_threshold_np.dtype,
+ iou_threshold_np.shape)
+ score_threshold = array_ops.placeholder(score_threshold_np.dtype,
+ score_threshold_np.shape)
+ with self.test_scope():
+ selected_indices = image_ops.non_max_suppression_padded(
+ boxes=boxes,
+ scores=scores,
+ max_output_size=max_output_size,
+ iou_threshold=iou_threshold,
+ score_threshold=score_threshold,
+ pad_to_max_output_size=True)
+ inputs_feed = {
+ boxes: boxes_np,
+ scores: scores_np,
+ iou_threshold: iou_threshold_np,
+ score_threshold: score_threshold_np
+ }
+ (indices_tf, num_valid) = sess.run(
+ selected_indices, feed_dict=inputs_feed)
+
+ self.assertEqual(indices_tf.size, max_output_size)
+ self.assertEqual(num_valid, 2)
+ self.assertAllClose(indices_tf[:num_valid], [3, 0])
+
+
if __name__ == "__main__":
test.main()