aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler
diff options
context:
space:
mode:
authorGravatar Tayo Oguntebi <tayo@google.com>2018-09-26 19:45:39 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-26 19:49:31 -0700
commit0d5c68e30f4637329fa233df506d7b97802a5e9b (patch)
tree64dad85cf7cc1ca199aa6f5949613ded029b3574 /tensorflow/compiler
parent85258e06edf424492905fd032b02ff4d420b9da1 (diff)
Fixes bug in tf2xla NMS implementation.
PiperOrigin-RevId: 214711381
Diffstat (limited to 'tensorflow/compiler')
-rw-r--r--tensorflow/compiler/tests/image_ops_test.py43
-rw-r--r--tensorflow/compiler/tf2xla/kernels/image_ops.cc9
2 files changed, 50 insertions, 2 deletions
diff --git a/tensorflow/compiler/tests/image_ops_test.py b/tensorflow/compiler/tests/image_ops_test.py
index bbe746e28f..68fdb5caf4 100644
--- a/tensorflow/compiler/tests/image_ops_test.py
+++ b/tensorflow/compiler/tests/image_ops_test.py
@@ -724,6 +724,49 @@ class NonMaxSuppressionTest(xla_test.XLATestCase):
self.assertEqual(num_valid, 2)
self.assertAllClose(indices_tf[:num_valid], [3, 0])
+ def testNMS3Then1WithScoreMaxThresh(self):
+ # Three boxes are selected based on IOU.
+ # One is filtered out by score threshold.
+ # One is filtered out by max_output_size.
+
+ with compat.forward_compatibility_horizon(2018, 8, 8):
+ boxes_data = [[0, 0, 1, 1], [0, 0.1, 1, 1.1], [0, -0.1, 1, 0.9],
+ [0, 10, 1, 11], [0, 10.1, 1, 11.1], [0, 100, 1, 101]]
+ boxes_np = np.array(boxes_data, dtype=np.float32)
+
+ scores_data = [0.9, 0.75, 0.6, 0.95, 0.5, 0.3]
+ scores_np = np.array(scores_data, dtype=np.float32)
+ max_output_size = 1
+ iou_threshold_np = np.array(0.5, dtype=np.float32)
+ score_threshold_np = np.array(0.4, dtype=np.float32)
+
+ with self.cached_session() as sess:
+ boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape)
+ scores = array_ops.placeholder(scores_np.dtype, shape=scores_np.shape)
+ iou_threshold = array_ops.placeholder(iou_threshold_np.dtype,
+ iou_threshold_np.shape)
+ score_threshold = array_ops.placeholder(score_threshold_np.dtype,
+ score_threshold_np.shape)
+ with self.test_scope():
+ selected_indices = image_ops.non_max_suppression_padded(
+ boxes=boxes,
+ scores=scores,
+ max_output_size=max_output_size,
+ iou_threshold=iou_threshold,
+ score_threshold=score_threshold,
+ pad_to_max_output_size=True)
+ inputs_feed = {
+ boxes: boxes_np,
+ scores: scores_np,
+ iou_threshold: iou_threshold_np,
+ score_threshold: score_threshold_np
+ }
+ (indices_tf, num_valid) = sess.run(
+ selected_indices, feed_dict=inputs_feed)
+
+ self.assertEqual(indices_tf.size, max_output_size)
+ self.assertEqual(num_valid, 1)
+ self.assertAllClose(indices_tf[:num_valid], [3])
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/compiler/tf2xla/kernels/image_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_ops.cc
index 33a73fe5fd..921b4340c0 100644
--- a/tensorflow/compiler/tf2xla/kernels/image_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/image_ops.cc
@@ -355,6 +355,9 @@ class NonMaxSuppressionOp : public XlaOpKernel {
OP_REQUIRES(
context, output_size >= 0,
errors::InvalidArgument("Need output_size >= 0, got ", output_size));
+ OP_REQUIRES(context, output_size <= kint32max,
+ errors::InvalidArgument("Need output_size <= kint32Max, got ",
+ output_size));
xla::XlaOp score_thresh = context->Input("score_threshold");
xla::XlaOp iou_thresh = context->Input("iou_threshold");
@@ -439,12 +442,14 @@ class NonMaxSuppressionOp : public XlaOpKernel {
xla::Broadcast(xla::ConstantR0<int32>(builder, 1), {num_boxes}),
xla::Broadcast(xla::ConstantR0<int32>(builder, 0), {num_boxes}));
- // num_valid is scalar.
- xla::XlaOp num_valid = xla::Reduce(
+ // num_valid is scalar. Value should be bound by output_size.
+ xla::XlaOp num_valid_total = xla::Reduce(
ones_included,
/*init_value=*/xla::ConstantR0<int>(builder, 0),
/*computation=*/CreateScalarAddComputation(xla::S32, builder),
/*dimensions_to_reduce=*/{0});
+ xla::XlaOp num_valid =
+ xla::Min(num_valid_total, xla::ConstantR0<int32>(builder, output_size));
xla::XlaOp output_tuple = TopK(scores_included, output_size);
xla::XlaOp selected_indices = xla::GetTupleElement(output_tuple, 1);