aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tf2xla
diff options
context:
space:
mode:
authorGravatar Tayo Oguntebi <tayo@google.com>2018-09-26 19:45:39 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-26 19:49:31 -0700
commit0d5c68e30f4637329fa233df506d7b97802a5e9b (patch)
tree64dad85cf7cc1ca199aa6f5949613ded029b3574 /tensorflow/compiler/tf2xla
parent85258e06edf424492905fd032b02ff4d420b9da1 (diff)
Fixes bug in tf2xla NMS implementation.
PiperOrigin-RevId: 214711381
Diffstat (limited to 'tensorflow/compiler/tf2xla')
-rw-r--r--tensorflow/compiler/tf2xla/kernels/image_ops.cc9
1 files changed, 7 insertions, 2 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/image_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_ops.cc
index 33a73fe5fd..921b4340c0 100644
--- a/tensorflow/compiler/tf2xla/kernels/image_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/image_ops.cc
@@ -355,6 +355,9 @@ class NonMaxSuppressionOp : public XlaOpKernel {
OP_REQUIRES(
context, output_size >= 0,
errors::InvalidArgument("Need output_size >= 0, got ", output_size));
+ OP_REQUIRES(context, output_size <= kint32max,
+ errors::InvalidArgument("Need output_size <= kint32Max, got ",
+ output_size));
xla::XlaOp score_thresh = context->Input("score_threshold");
xla::XlaOp iou_thresh = context->Input("iou_threshold");
@@ -439,12 +442,14 @@ class NonMaxSuppressionOp : public XlaOpKernel {
xla::Broadcast(xla::ConstantR0<int32>(builder, 1), {num_boxes}),
xla::Broadcast(xla::ConstantR0<int32>(builder, 0), {num_boxes}));
- // num_valid is scalar.
- xla::XlaOp num_valid = xla::Reduce(
+ // num_valid is scalar. Value should be bound by output_size.
+ xla::XlaOp num_valid_total = xla::Reduce(
ones_included,
/*init_value=*/xla::ConstantR0<int>(builder, 0),
/*computation=*/CreateScalarAddComputation(xla::S32, builder),
/*dimensions_to_reduce=*/{0});
+ xla::XlaOp num_valid =
+ xla::Min(num_valid_total, xla::ConstantR0<int32>(builder, output_size));
xla::XlaOp output_tuple = TopK(scores_included, output_size);
xla::XlaOp selected_indices = xla::GetTupleElement(output_tuple, 1);