diff options
author | Tayo Oguntebi <tayo@google.com> | 2018-09-26 19:45:39 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-26 19:49:31 -0700 |
commit | 0d5c68e30f4637329fa233df506d7b97802a5e9b (patch) | |
tree | 64dad85cf7cc1ca199aa6f5949613ded029b3574 /tensorflow/compiler | |
parent | 85258e06edf424492905fd032b02ff4d420b9da1 (diff) |
Fixes bug in tf2xla NMS implementation.
PiperOrigin-RevId: 214711381
Diffstat (limited to 'tensorflow/compiler')
-rw-r--r-- | tensorflow/compiler/tests/image_ops_test.py | 43 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/image_ops.cc | 9 |
2 files changed, 50 insertions, 2 deletions
diff --git a/tensorflow/compiler/tests/image_ops_test.py b/tensorflow/compiler/tests/image_ops_test.py index bbe746e28f..68fdb5caf4 100644 --- a/tensorflow/compiler/tests/image_ops_test.py +++ b/tensorflow/compiler/tests/image_ops_test.py @@ -724,6 +724,49 @@ class NonMaxSuppressionTest(xla_test.XLATestCase): self.assertEqual(num_valid, 2) self.assertAllClose(indices_tf[:num_valid], [3, 0]) + def testNMS3Then1WithScoreMaxThresh(self): + # Three boxes are selected based on IOU. + # One is filtered out by score threshold. + # One is filtered out by max_output_size. + + with compat.forward_compatibility_horizon(2018, 8, 8): + boxes_data = [[0, 0, 1, 1], [0, 0.1, 1, 1.1], [0, -0.1, 1, 0.9], + [0, 10, 1, 11], [0, 10.1, 1, 11.1], [0, 100, 1, 101]] + boxes_np = np.array(boxes_data, dtype=np.float32) + + scores_data = [0.9, 0.75, 0.6, 0.95, 0.5, 0.3] + scores_np = np.array(scores_data, dtype=np.float32) + max_output_size = 1 + iou_threshold_np = np.array(0.5, dtype=np.float32) + score_threshold_np = np.array(0.4, dtype=np.float32) + + with self.cached_session() as sess: + boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape) + scores = array_ops.placeholder(scores_np.dtype, shape=scores_np.shape) + iou_threshold = array_ops.placeholder(iou_threshold_np.dtype, + iou_threshold_np.shape) + score_threshold = array_ops.placeholder(score_threshold_np.dtype, + score_threshold_np.shape) + with self.test_scope(): + selected_indices = image_ops.non_max_suppression_padded( + boxes=boxes, + scores=scores, + max_output_size=max_output_size, + iou_threshold=iou_threshold, + score_threshold=score_threshold, + pad_to_max_output_size=True) + inputs_feed = { + boxes: boxes_np, + scores: scores_np, + iou_threshold: iou_threshold_np, + score_threshold: score_threshold_np + } + (indices_tf, num_valid) = sess.run( + selected_indices, feed_dict=inputs_feed) + + self.assertEqual(indices_tf.size, max_output_size) + self.assertEqual(num_valid, 1) + self.assertAllClose(indices_tf[:num_valid], [3]) if __name__ == "__main__": test.main() diff --git a/tensorflow/compiler/tf2xla/kernels/image_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_ops.cc index 33a73fe5fd..921b4340c0 100644 --- a/tensorflow/compiler/tf2xla/kernels/image_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/image_ops.cc @@ -355,6 +355,9 @@ class NonMaxSuppressionOp : public XlaOpKernel { OP_REQUIRES( context, output_size >= 0, errors::InvalidArgument("Need output_size >= 0, got ", output_size)); + OP_REQUIRES(context, output_size <= kint32max, + errors::InvalidArgument("Need output_size <= kint32Max, got ", + output_size)); xla::XlaOp score_thresh = context->Input("score_threshold"); xla::XlaOp iou_thresh = context->Input("iou_threshold"); @@ -439,12 +442,14 @@ class NonMaxSuppressionOp : public XlaOpKernel { xla::Broadcast(xla::ConstantR0<int32>(builder, 1), {num_boxes}), xla::Broadcast(xla::ConstantR0<int32>(builder, 0), {num_boxes})); - // num_valid is scalar. - xla::XlaOp num_valid = xla::Reduce( + // num_valid is scalar. Value should be bound by output_size. + xla::XlaOp num_valid_total = xla::Reduce( ones_included, /*init_value=*/xla::ConstantR0<int>(builder, 0), /*computation=*/CreateScalarAddComputation(xla::S32, builder), /*dimensions_to_reduce=*/{0}); + xla::XlaOp num_valid = + xla::Min(num_valid_total, xla::ConstantR0<int32>(builder, output_size)); xla::XlaOp output_tuple = TopK(scores_included, output_size); xla::XlaOp selected_indices = xla::GetTupleElement(output_tuple, 1); |