diff options
Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/topk_op.cc')
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/topk_op.cc | 40 |
1 files changed, 21 insertions, 19 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/topk_op.cc b/tensorflow/compiler/tf2xla/kernels/topk_op.cc index 1ddcb08c8e..82d4a69777 100644 --- a/tensorflow/compiler/tf2xla/kernels/topk_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/topk_op.cc @@ -41,33 +41,35 @@ class TopKOp : public XlaOpKernel { OP_REQUIRES(context, input_shape.dims() >= 1, errors::InvalidArgument("input must be >= 1-D, got shape ", input_shape.DebugString())); + int last_dim = input_shape.dims() - 1; + int last_dim_size = input_shape.dim_size(last_dim); OP_REQUIRES( - context, input_shape.dim_size(input_shape.dims() - 1) >= k, + context, last_dim_size >= k, errors::InvalidArgument("input must have at least k columns. Had ", - input_shape.dim_size(input_shape.dims() - 1), - ", needed ", k)); - - OP_REQUIRES( - context, input_shape.dims() == 1, - errors::Unimplemented("TopK is implemented for 1-D inputs, got shape ", - input_shape.DebugString())); + last_dim_size, ", needed ", k)); xla::XlaBuilder* const b = context->builder(); - if (input_shape.dim_size(0) < k) { - k = input_shape.dim_size(0); + if (last_dim_size < k) { + k = last_dim_size; } const xla::XlaOp input = context->Input(0); - xla::XlaOp iota_s32 = xla::Iota(b, xla::S32, input_shape.dim_size(0)); - xla::XlaOp sort_result = xla::Sort(xla::Neg(input), iota_s32); + + xla::XlaOp iota_s32 = xla::Iota(b, xla::S32, last_dim_size); + auto input_dims = input_shape.dim_sizes(); + std::vector<int64> broadcast_dims(input_dims.begin(), input_dims.end() - 1); + xla::XlaOp broadcast_s32 = xla::Broadcast(iota_s32, broadcast_dims); + xla::XlaOp sort_result = xla::Sort(xla::Neg(input), broadcast_s32); + + std::vector<int64> start_indices(input_shape.dims(), 0); + std::vector<int64> limit_indices(input_dims.begin(), input_dims.end()); + limit_indices[last_dim] = k; + std::vector<int64> strides(input_shape.dims(), 1); + xla::XlaOp values = - xla::Neg(xla::Slice(xla::GetTupleElement(sort_result, 0), - /*start_indices=*/{0}, - /*limit_indices=*/{k}, - /*strides=*/{1})); + xla::Neg(xla::Slice(xla::GetTupleElement(sort_result, 0), start_indices, + limit_indices, strides)); xla::XlaOp indices = xla::Slice(xla::GetTupleElement(sort_result, 1), - /*start_indices=*/{0}, - /*limit_indices=*/{k}, - /*strides=*/{1}); + start_indices, limit_indices, strides); context->SetOutput(0, values); context->SetOutput(1, indices); } |