aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tf2xla/kernels/topk_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/topk_op.cc')
-rw-r--r--tensorflow/compiler/tf2xla/kernels/topk_op.cc40
1 files changed, 21 insertions, 19 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/topk_op.cc b/tensorflow/compiler/tf2xla/kernels/topk_op.cc
index 1ddcb08c8e..82d4a69777 100644
--- a/tensorflow/compiler/tf2xla/kernels/topk_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/topk_op.cc
@@ -41,33 +41,35 @@ class TopKOp : public XlaOpKernel {
OP_REQUIRES(context, input_shape.dims() >= 1,
errors::InvalidArgument("input must be >= 1-D, got shape ",
input_shape.DebugString()));
+ int last_dim = input_shape.dims() - 1;
+ int last_dim_size = input_shape.dim_size(last_dim);
OP_REQUIRES(
- context, input_shape.dim_size(input_shape.dims() - 1) >= k,
+ context, last_dim_size >= k,
errors::InvalidArgument("input must have at least k columns. Had ",
- input_shape.dim_size(input_shape.dims() - 1),
- ", needed ", k));
-
- OP_REQUIRES(
- context, input_shape.dims() == 1,
- errors::Unimplemented("TopK is implemented for 1-D inputs, got shape ",
- input_shape.DebugString()));
+ last_dim_size, ", needed ", k));
xla::XlaBuilder* const b = context->builder();
- if (input_shape.dim_size(0) < k) {
- k = input_shape.dim_size(0);
+ if (last_dim_size < k) {
+ k = last_dim_size;
}
const xla::XlaOp input = context->Input(0);
- xla::XlaOp iota_s32 = xla::Iota(b, xla::S32, input_shape.dim_size(0));
- xla::XlaOp sort_result = xla::Sort(xla::Neg(input), iota_s32);
+
+ xla::XlaOp iota_s32 = xla::Iota(b, xla::S32, last_dim_size);
+ auto input_dims = input_shape.dim_sizes();
+ std::vector<int64> broadcast_dims(input_dims.begin(), input_dims.end() - 1);
+ xla::XlaOp broadcast_s32 = xla::Broadcast(iota_s32, broadcast_dims);
+ xla::XlaOp sort_result = xla::Sort(xla::Neg(input), broadcast_s32);
+
+ std::vector<int64> start_indices(input_shape.dims(), 0);
+ std::vector<int64> limit_indices(input_dims.begin(), input_dims.end());
+ limit_indices[last_dim] = k;
+ std::vector<int64> strides(input_shape.dims(), 1);
+
xla::XlaOp values =
- xla::Neg(xla::Slice(xla::GetTupleElement(sort_result, 0),
- /*start_indices=*/{0},
- /*limit_indices=*/{k},
- /*strides=*/{1}));
+ xla::Neg(xla::Slice(xla::GetTupleElement(sort_result, 0), start_indices,
+ limit_indices, strides));
xla::XlaOp indices = xla::Slice(xla::GetTupleElement(sort_result, 1),
- /*start_indices=*/{0},
- /*limit_indices=*/{k},
- /*strides=*/{1});
+ start_indices, limit_indices, strides);
context->SetOutput(0, values);
context->SetOutput(1, indices);
}