aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/topk_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/topk_op.cc')
-rw-r--r--tensorflow/core/kernels/topk_op.cc71
1 files changed, 71 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/topk_op.cc b/tensorflow/core/kernels/topk_op.cc
new file mode 100644
index 0000000000..79b5d4d07e
--- /dev/null
+++ b/tensorflow/core/kernels/topk_op.cc
@@ -0,0 +1,71 @@
+// See docs in ../ops/nn_ops.cc.
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/lib/gtl/top_n.h"
+#include "tensorflow/core/public/tensor_shape.h"
+#include "tensorflow/core/public/tensor.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+
+template <typename T>
+class TopK : public OpKernel {
+ public:
+ explicit TopK(OpKernelConstruction* context) : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("k", &k_));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ const auto& input_in = context->input(0);
+ OP_REQUIRES(context, input_in.dims() == 2,
+ errors::InvalidArgument("input must be 2-dimensional"));
+ OP_REQUIRES(context, input_in.dim_size(1) >= k_,
+ errors::InvalidArgument("input must have at least k columns"));
+
+ const auto& input = input_in.matrix<T>();
+
+ const auto num_rows = input_in.dim_size(0); // generally batch_size
+ const auto num_cols = input_in.dim_size(1);
+
+ Tensor* values_out = nullptr;
+ OP_REQUIRES_OK(context, context->allocate_output(
+ 0, TensorShape({num_rows, k_}), &values_out));
+ Tensor* indices_out = nullptr;
+ OP_REQUIRES_OK(context, context->allocate_output(
+ 1, TensorShape({num_rows, k_}), &indices_out));
+ auto values = values_out->matrix<T>();
+ auto indices = indices_out->matrix<int32>();
+
+ gtl::TopN<std::pair<T, int32>> filter(k_);
+
+ for (int r = 0; r < num_rows; r++) {
+ for (int32 c = 0; c < num_cols; ++c) {
+ // The second element is the negated index, so that lower-index elements
+ // are considered larger than higher-index elements in case of ties.
+ filter.push(std::make_pair(input(r, c), -c));
+ }
+
+ std::unique_ptr<std::vector<std::pair<T, int32>>> top_k(filter.Extract());
+ for (int32 i = 0; i < k_; ++i) {
+ values(r, i) = (*top_k)[i].first;
+ indices(r, i) = -(*top_k)[i].second;
+ }
+ filter.Reset();
+ }
+ }
+
+ private:
+ int k_;
+};
+
+#define REGISTER_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("TopK").Device(DEVICE_CPU).TypeConstraint<type>("T"), TopK<type>)
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
+#undef REGISTER_KERNELS
+
+} // namespace tensorflow